commit 4114c2e0df5b1f6781efc54e2aa1ed8c4af8d031 Author: Lu Xinlong Date: Wed Aug 20 14:29:42 2025 +0800 initial commit diff --git a/Dockerfile.funasr-mr100 b/Dockerfile.funasr-mr100 new file mode 100644 index 0000000..11fd115 --- /dev/null +++ b/Dockerfile.funasr-mr100 @@ -0,0 +1,20 @@ +FROM corex:4.3.0 + +WORKDIR /root + +COPY requirements.txt /root +RUN pip install -r requirements.txt + +RUN apt update && apt install -y vim net-tools + +RUN pip install funasr==1.2.6 openai-whisper + +ADD . /root/ +ADD nltk_data.tar.gz /root/ +RUN tar -xvzf nltk_data.tar.gz + +RUN cp ./replaced_files/mr_v100/cif_predictor.py /usr/local/lib/python3.10/site-packages/funasr/models/paraformer/ + +EXPOSE 80 +ENTRYPOINT ["bash"] +CMD ["./start_funasr.sh"] \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..ef663dd --- /dev/null +++ b/README.md @@ -0,0 +1,39 @@ +# 天数智芯 智铠100 FunASR + +## 镜像构造 +```shell +docker build -f ./Dockerfile.funasr-mr100 -t funasr-mr100:latest . +``` + +## 使用说明 +### 快速镜像测试 +对funasr的测试需要在以上构造好的镜像容器内测试,测试步骤 +1. 将需要测试的音频wav文件和相应的ground truth文件(含有音频的正确内容文字的文本文件)放置于当前文件夹,并且准备好相应的ASR模型路径 +2. 
快速测试命令 +```shell +docker run -it \ + -v /usr/src:/usr/src \ + -v /lib/modules:/lib/modules --device=/dev/iluvatar0:/dev/iluvatar0 \ + -v $PWD:/tmp/workspace \ + -v <host_model_dir>:<container_model_dir> \ + -e MODEL_DIR=<container_model_dir> \ + -e TEST_FILE=<test_wav_file> \ + -e ANSWER_FILE=<ground_truth_txt_file> \ + --cpus=4 --memory=16g \ + <image_name> +``` + +### 定制化手动运行 + +用户可使用类似上述的docker run指令以交互形式进入镜像中,主要的测试代码为`test_funasr.py`,用户可自行修改代码中需要测试的模型路径、测试文件路径以及调用funASR逻辑 + +## 智铠100模型适配情况 +我们在智铠100上针对funASR部分进行了所有大类的适配,测试方式为在Nvidia A100环境下和智铠100加速卡上对同一段长音频进行语音识别任务,获取运行时间,1-cer指标。运行时都只使用一张显卡 + +| 模型大类 | 模型地址 |A100运行时间(秒)|智铠100运行时间(秒)|A100 1-cer|智铠100 1-cer| 备注 | +|------|---------------|-----|----|-------|-------|---------------------| +| sense_voice | https://www.modelscope.cn/models/iic/SenseVoiceSmall | 1.8327 | 1.2579 | 0.980033 | 0.980033 | | +| whisper | https://www.modelscope.cn/models/iic/Whisper-large-v3 | 23.8337 | 22.9085 | 0.910150 | 0.910150 | | +| paraformer | https://modelscope.cn/models/iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch | 4.7246 | 4.7719 | 0.955075 | 0.955075 | | +| conformer | https://www.modelscope.cn/models/iic/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch | 95.9631 | 125.8649 | 0.349418 | 0.346090 | | +| uni_asr | https://www.modelscope.cn/models/iic/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline | 70.5289 | 88.9481 | 0.717138 | 0.717138 | 该部分的适配修改了一些funASR源码 | \ No newline at end of file diff --git a/download_nltk_model.py b/download_nltk_model.py new file mode 100644 index 0000000..10910f9 --- /dev/null +++ b/download_nltk_model.py @@ -0,0 +1,4 @@ +import nltk +nltk.download('punkt') +nltk.download('wordnet') +nltk.download('omw-1.4') \ No newline at end of file diff --git a/nltk_data.tar.gz b/nltk_data.tar.gz new file mode 100644 index 0000000..68a9b4c Binary files /dev/null and b/nltk_data.tar.gz differ diff --git a/replaced_files/mr_v100/cif_predictor.py b/replaced_files/mr_v100/cif_predictor.py new file mode 100644 index 0000000..9b19ba9 --- /dev/null
+++ b/replaced_files/mr_v100/cif_predictor.py @@ -0,0 +1,762 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +import torch +import logging +import numpy as np + +from funasr.register import tables +from funasr.train_utils.device_funcs import to_device +from funasr.models.transformer.utils.nets_utils import make_pad_mask +from torch.cuda.amp import autocast + + +@tables.register("predictor_classes", "CifPredictor") +class CifPredictor(torch.nn.Module): + def __init__( + self, + idim, + l_order, + r_order, + threshold=1.0, + dropout=0.1, + smooth_factor=1.0, + noise_threshold=0, + tail_threshold=0.45, + ): + super().__init__() + + self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0) + self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim) + self.cif_output = torch.nn.Linear(idim, 1) + self.dropout = torch.nn.Dropout(p=dropout) + self.threshold = threshold + self.smooth_factor = smooth_factor + self.noise_threshold = noise_threshold + self.tail_threshold = tail_threshold + + def forward( + self, + hidden, + target_label=None, + mask=None, + ignore_id=-1, + mask_chunk_predictor=None, + target_label_length=None, + ): + + with autocast(False): + h = hidden + context = h.transpose(1, 2) + queries = self.pad(context) + memory = self.cif_conv1d(queries) + output = memory + context + output = self.dropout(output) + output = output.transpose(1, 2) + output = torch.relu(output) + output = self.cif_output(output) + alphas = torch.sigmoid(output) + alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold) + if mask is not None: + mask = mask.transpose(-1, -2).float() + alphas = alphas * mask + if mask_chunk_predictor is not None: + alphas = alphas * mask_chunk_predictor + alphas = alphas.squeeze(-1) + mask = mask.squeeze(-1) + if target_label_length is not None: + 
target_length = target_label_length + elif target_label is not None: + target_length = (target_label != ignore_id).float().sum(-1) + else: + target_length = None + token_num = alphas.sum(-1) + if target_length is not None: + alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1)) + elif self.tail_threshold > 0.0: + hidden, alphas, token_num = self.tail_process_fn( + hidden, alphas, token_num, mask=mask + ) + + acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold) + + if target_length is None and self.tail_threshold > 0.0: + token_num_int = torch.max(token_num).type(torch.int32).item() + acoustic_embeds = acoustic_embeds[:, :token_num_int, :] + + return acoustic_embeds, token_num, alphas, cif_peak + + def tail_process_fn(self, hidden, alphas, token_num=None, mask=None): + b, t, d = hidden.size() + tail_threshold = self.tail_threshold + if mask is not None: + zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device) + ones_t = torch.ones_like(zeros_t) + mask_1 = torch.cat([mask, zeros_t], dim=1) + mask_2 = torch.cat([ones_t, mask], dim=1) + mask = mask_2 - mask_1 + tail_threshold = mask * tail_threshold + alphas = torch.cat([alphas, zeros_t], dim=1) + alphas = torch.add(alphas, tail_threshold) + else: + tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device) + tail_threshold = torch.reshape(tail_threshold, (1, 1)) + alphas = torch.cat([alphas, tail_threshold], dim=1) + zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device) + hidden = torch.cat([hidden, zeros], dim=1) + token_num = alphas.sum(dim=-1) + token_num_floor = torch.floor(token_num) + + return hidden, alphas, token_num_floor + + def gen_frame_alignments( + self, alphas: torch.Tensor = None, encoder_sequence_length: torch.Tensor = None + ): + batch_size, maximum_length = alphas.size() + int_type = torch.int32 + + is_training = self.training + if is_training: + token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type) 
+ else: + token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type) + + max_token_num = torch.max(token_num).item() + + alphas_cumsum = torch.cumsum(alphas, dim=1) + alphas_cumsum = torch.floor(alphas_cumsum).type(int_type) + alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1) + + index = torch.ones([batch_size, max_token_num], dtype=int_type) + index = torch.cumsum(index, dim=1) + index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device) + + index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type) + index_div_bool_zeros = index_div.eq(0) + index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1 + index_div_bool_zeros_count = torch.clamp( + index_div_bool_zeros_count, 0, encoder_sequence_length.max() + ) + token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device) + index_div_bool_zeros_count *= token_num_mask + + index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat( + 1, 1, maximum_length + ) + ones = torch.ones_like(index_div_bool_zeros_count_tile) + zeros = torch.zeros_like(index_div_bool_zeros_count_tile) + ones = torch.cumsum(ones, dim=2) + cond = index_div_bool_zeros_count_tile == ones + index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones) + + index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool) + index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type) + index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1) + index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type) + predictor_mask = ( + (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())) + .type(int_type) + .to(encoder_sequence_length.device) + ) + index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask + + predictor_alignments = 
index_div_bool_zeros_count_tile_out + predictor_alignments_length = predictor_alignments.sum(-1).type( + encoder_sequence_length.dtype + ) + return predictor_alignments.detach(), predictor_alignments_length.detach() + + +@tables.register("predictor_classes", "CifPredictorV2") +class CifPredictorV2(torch.nn.Module): + def __init__( + self, + idim, + l_order, + r_order, + threshold=1.0, + dropout=0.1, + smooth_factor=1.0, + noise_threshold=0, + tail_threshold=0.0, + tf2torch_tensor_name_prefix_torch="predictor", + tf2torch_tensor_name_prefix_tf="seq2seq/cif", + tail_mask=True, + ): + super().__init__() + + self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0) + self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1) + self.cif_output = torch.nn.Linear(idim, 1) + self.dropout = torch.nn.Dropout(p=dropout) + self.threshold = threshold + self.smooth_factor = smooth_factor + self.noise_threshold = noise_threshold + self.tail_threshold = tail_threshold + self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch + self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf + self.tail_mask = tail_mask + + def forward( + self, + hidden, + target_label=None, + mask=None, + ignore_id=-1, + mask_chunk_predictor=None, + target_label_length=None, + ): + + with autocast(False): + h = hidden + context = h.transpose(1, 2) + queries = self.pad(context) + output = torch.relu(self.cif_conv1d(queries)) + output = output.transpose(1, 2) + + output = self.cif_output(output) + alphas = torch.sigmoid(output) + alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold) + if mask is not None: + mask = mask.transpose(-1, -2).float() + alphas = alphas * mask + if mask_chunk_predictor is not None: + alphas = alphas * mask_chunk_predictor + alphas = alphas.squeeze(-1) + mask = mask.squeeze(-1) + if target_label_length is not None: + target_length = target_label_length.squeeze(-1) + elif target_label is not None: + target_length 
= (target_label != ignore_id).float().sum(-1) + else: + target_length = None + token_num = alphas.sum(-1) + if target_length is not None: + alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1)) + elif self.tail_threshold > 0.0: + if self.tail_mask: + hidden, alphas, token_num = self.tail_process_fn( + hidden, alphas, token_num, mask=mask + ) + else: + hidden, alphas, token_num = self.tail_process_fn( + hidden, alphas, token_num, mask=None + ) + + acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold) + if target_length is None and self.tail_threshold > 0.0: + token_num_int = torch.max(token_num).type(torch.int32).item() + acoustic_embeds = acoustic_embeds[:, :token_num_int, :] + + return acoustic_embeds, token_num, alphas, cif_peak + + def forward_chunk(self, hidden, cache=None, **kwargs): + is_final = kwargs.get("is_final", False) + batch_size, len_time, hidden_size = hidden.shape + h = hidden + context = h.transpose(1, 2) + queries = self.pad(context) + output = torch.relu(self.cif_conv1d(queries)) + output = output.transpose(1, 2) + output = self.cif_output(output) + alphas = torch.sigmoid(output) + alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold) + + alphas = alphas.squeeze(-1) + + token_length = [] + list_fires = [] + list_frames = [] + cache_alphas = [] + cache_hiddens = [] + + if cache is not None and "chunk_size" in cache: + alphas[:, : cache["chunk_size"][0]] = 0.0 + if not is_final: + alphas[:, sum(cache["chunk_size"][:2]) :] = 0.0 + if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache: + cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device) + cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device) + hidden = torch.cat((cache["cif_hidden"], hidden), dim=1) + alphas = torch.cat((cache["cif_alphas"], alphas), dim=1) + if cache is not None and is_final: + tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device) 
+ tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device) + tail_alphas = torch.tile(tail_alphas, (batch_size, 1)) + hidden = torch.cat((hidden, tail_hidden), dim=1) + alphas = torch.cat((alphas, tail_alphas), dim=1) + + len_time = alphas.shape[1] + for b in range(batch_size): + integrate = 0.0 + frames = torch.zeros((hidden_size), device=hidden.device) + list_frame = [] + list_fire = [] + for t in range(len_time): + alpha = alphas[b][t] + if alpha + integrate < self.threshold: + integrate += alpha + list_fire.append(integrate) + frames += alpha * hidden[b][t] + else: + frames += (self.threshold - integrate) * hidden[b][t] + list_frame.append(frames) + integrate += alpha + list_fire.append(integrate) + integrate -= self.threshold + frames = integrate * hidden[b][t] + + cache_alphas.append(integrate) + if integrate > 0.0: + cache_hiddens.append(frames / integrate) + else: + cache_hiddens.append(frames) + + token_length.append(torch.tensor(len(list_frame), device=alphas.device)) + list_fires.append(list_fire) + list_frames.append(list_frame) + + cache["cif_alphas"] = torch.stack(cache_alphas, axis=0) + cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0) + cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0) + cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0) + + max_token_len = max(token_length) + if max_token_len == 0: + return hidden, torch.stack(token_length, 0), None, None + list_ls = [] + for b in range(batch_size): + pad_frames = torch.zeros( + (max_token_len - token_length[b], hidden_size), device=alphas.device + ) + if token_length[b] == 0: + list_ls.append(pad_frames) + else: + list_frames[b] = torch.stack(list_frames[b]) + list_ls.append(torch.cat((list_frames[b], pad_frames), dim=0)) + + cache["cif_alphas"] = torch.stack(cache_alphas, axis=0) + cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0) + cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0) + cache["cif_hidden"] = 
torch.unsqueeze(cache["cif_hidden"], axis=0) + return torch.stack(list_ls, 0), torch.stack(token_length, 0), None, None + + def tail_process_fn(self, hidden, alphas, token_num=None, mask=None): + b, t, d = hidden.size() + tail_threshold = self.tail_threshold + if mask is not None: + zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device) + ones_t = torch.ones_like(zeros_t) + mask_1 = torch.cat([mask, zeros_t], dim=1) + mask_2 = torch.cat([ones_t, mask], dim=1) + mask = mask_2 - mask_1 + tail_threshold = mask * tail_threshold + alphas = torch.cat([alphas, zeros_t], dim=1) + alphas = torch.add(alphas, tail_threshold) + else: + tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device) + tail_threshold = torch.reshape(tail_threshold, (1, 1)) + if b > 1: + alphas = torch.cat([alphas, tail_threshold.repeat(b, 1)], dim=1) + else: + alphas = torch.cat([alphas, tail_threshold], dim=1) + zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device) + hidden = torch.cat([hidden, zeros], dim=1) + token_num = alphas.sum(dim=-1) + token_num_floor = torch.floor(token_num) + + return hidden, alphas, token_num_floor + + def gen_frame_alignments( + self, alphas: torch.Tensor = None, encoder_sequence_length: torch.Tensor = None + ): + batch_size, maximum_length = alphas.size() + int_type = torch.int32 + + is_training = self.training + if is_training: + token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type) + else: + token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type) + + max_token_num = torch.max(token_num).item() + + alphas_cumsum = torch.cumsum(alphas, dim=1) + alphas_cumsum = torch.floor(alphas_cumsum).type(int_type) + alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1) + + index = torch.ones([batch_size, max_token_num], dtype=int_type) + index = torch.cumsum(index, dim=1) + index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device) + + index_div = 
torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type) + index_div_bool_zeros = index_div.eq(0) + index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1 + index_div_bool_zeros_count = torch.clamp( + index_div_bool_zeros_count, 0, encoder_sequence_length.max() + ) + token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device) + index_div_bool_zeros_count *= token_num_mask + + index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat( + 1, 1, maximum_length + ) + ones = torch.ones_like(index_div_bool_zeros_count_tile) + zeros = torch.zeros_like(index_div_bool_zeros_count_tile) + ones = torch.cumsum(ones, dim=2) + cond = index_div_bool_zeros_count_tile == ones + index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones) + + index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool) + index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type) + index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1) + index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type) + predictor_mask = ( + (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())) + .type(int_type) + .to(encoder_sequence_length.device) + ) + index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask + + predictor_alignments = index_div_bool_zeros_count_tile_out + predictor_alignments_length = predictor_alignments.sum(-1).type( + encoder_sequence_length.dtype + ) + return predictor_alignments.detach(), predictor_alignments_length.detach() + + +@tables.register("predictor_classes", "CifPredictorV2Export") +class CifPredictorV2Export(torch.nn.Module): + def __init__(self, model, **kwargs): + super().__init__() + + self.pad = model.pad + self.cif_conv1d = model.cif_conv1d + self.cif_output = model.cif_output + self.threshold = model.threshold + 
self.smooth_factor = model.smooth_factor + self.noise_threshold = model.noise_threshold + self.tail_threshold = model.tail_threshold + + def forward( + self, + hidden: torch.Tensor, + mask: torch.Tensor, + ): + alphas, token_num = self.forward_cnn(hidden, mask) + mask = mask.transpose(-1, -2).float() + mask = mask.squeeze(-1) + hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask) + acoustic_embeds, cif_peak = cif_v1_export(hidden, alphas, self.threshold) + + return acoustic_embeds, token_num, alphas, cif_peak + + def forward_cnn( + self, + hidden: torch.Tensor, + mask: torch.Tensor, + ): + h = hidden + context = h.transpose(1, 2) + queries = self.pad(context) + output = torch.relu(self.cif_conv1d(queries)) + output = output.transpose(1, 2) + + output = self.cif_output(output) + alphas = torch.sigmoid(output) + alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold) + mask = mask.transpose(-1, -2).float() + alphas = alphas * mask + alphas = alphas.squeeze(-1) + token_num = alphas.sum(-1) + + return alphas, token_num + + def tail_process_fn(self, hidden, alphas, token_num=None, mask=None): + b, t, d = hidden.size() + tail_threshold = self.tail_threshold + + zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device) + ones_t = torch.ones_like(zeros_t) + + mask_1 = torch.cat([mask, zeros_t], dim=1) + mask_2 = torch.cat([ones_t, mask], dim=1) + mask = mask_2 - mask_1 + tail_threshold = mask * tail_threshold + alphas = torch.cat([alphas, zeros_t], dim=1) + alphas = torch.add(alphas, tail_threshold) + + zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device) + hidden = torch.cat([hidden, zeros], dim=1) + token_num = alphas.sum(dim=-1) + token_num_floor = torch.floor(token_num) + + return hidden, alphas, token_num_floor + + +@torch.jit.script +def cif_v1_export(hidden, alphas, threshold: float): + device = hidden.device + dtype = hidden.dtype + batch_size, len_time, hidden_size = hidden.size() 
+ threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device) + + frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device) + fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device) + + # prefix_sum = torch.cumsum(alphas, dim=1) + prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to( + torch.float32 + ) # cumsum precision degradation cause wrong result in extreme + prefix_sum_floor = torch.floor(prefix_sum) + dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1) + dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum) + + dislocation_prefix_sum_floor[:, 0] = 0 + dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor + + fire_idxs = dislocation_diff > 0 + fires[fire_idxs] = 1 + fires = fires + prefix_sum - prefix_sum_floor + + # prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1) + prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).repeat((1, 1, hidden_size)) * hidden, dim=1) + frames = prefix_sum_hidden[fire_idxs] + shift_frames = torch.roll(frames, 1, dims=0) + + batch_len = fire_idxs.sum(1) + batch_idxs = torch.cumsum(batch_len, dim=0) + shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0) + shift_batch_idxs[0] = 0 + shift_frames[shift_batch_idxs] = 0 + + remains = fires - torch.floor(fires) + # remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs] + remain_frames = remains[fire_idxs].unsqueeze(-1).repeat((1, hidden_size)) * hidden[fire_idxs] + + shift_remain_frames = torch.roll(remain_frames, 1, dims=0) + shift_remain_frames[shift_batch_idxs] = 0 + + frames = frames - shift_frames + shift_remain_frames - remain_frames + + # max_label_len = batch_len.max() + max_label_len = alphas.sum(dim=-1) + max_label_len = torch.floor(max_label_len).max().to(dtype=torch.int64) + + # frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device) + frame_fires = 
torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device) + indices = torch.arange(max_label_len, device=device).expand(batch_size, -1) + frame_fires_idxs = indices < batch_len.unsqueeze(1) + frame_fires[frame_fires_idxs] = frames + return frame_fires, fires + + +@torch.jit.script +def cif_export(hidden, alphas, threshold: float): + batch_size, len_time, hidden_size = hidden.size() + threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device) + + # loop varss + integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=hidden.device) + frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device) + # intermediate vars along time + list_fires = [] + list_frames = [] + + for t in range(len_time): + alpha = alphas[:, t] + distribution_completion = ( + torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device) - integrate + ) + + integrate += alpha + list_fires.append(integrate) + + fire_place = integrate >= threshold + integrate = torch.where( + fire_place, + integrate - torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device), + integrate, + ) + cur = torch.where(fire_place, distribution_completion, alpha) + remainds = alpha - cur + + frame += cur[:, None] * hidden[:, t, :] + list_frames.append(frame) + frame = torch.where( + fire_place[:, None].repeat(1, hidden_size), remainds[:, None] * hidden[:, t, :], frame + ) + + fires = torch.stack(list_fires, 1) + frames = torch.stack(list_frames, 1) + + fire_idxs = fires >= threshold + frame_fires = torch.zeros_like(hidden) + max_label_len = frames[0, fire_idxs[0]].size(0) + for b in range(batch_size): + frame_fire = frames[b, fire_idxs[b]] + frame_len = frame_fire.size(0) + frame_fires[b, :frame_len, :] = frame_fire + + if frame_len >= max_label_len: + max_label_len = frame_len + frame_fires = frame_fires[:, :max_label_len, :] + return frame_fires, fires + + +class mae_loss(torch.nn.Module): + + def __init__(self, 
normalize_length=False): + super(mae_loss, self).__init__() + self.normalize_length = normalize_length + self.criterion = torch.nn.L1Loss(reduction="sum") + + def forward(self, token_length, pre_token_length): + loss_token_normalizer = token_length.size(0) + if self.normalize_length: + loss_token_normalizer = token_length.sum().type(torch.float32) + loss = self.criterion(token_length, pre_token_length) + loss = loss / loss_token_normalizer + return loss + + +def cif(hidden, alphas, threshold): + batch_size, len_time, hidden_size = hidden.size() + + # loop varss + integrate = torch.zeros([batch_size], device=hidden.device) + frame = torch.zeros([batch_size, hidden_size], device=hidden.device) + # intermediate vars along time + list_fires = [] + list_frames = [] + + for t in range(len_time): + alpha = alphas[:, t] + distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate + + integrate += alpha + list_fires.append(integrate) + + fire_place = integrate >= threshold + integrate = torch.where( + fire_place, integrate - torch.ones([batch_size], device=hidden.device), integrate + ) + cur = torch.where(fire_place, distribution_completion, alpha) + remainds = alpha - cur + + frame += cur[:, None] * hidden[:, t, :] + list_frames.append(frame) + frame = torch.where( + fire_place[:, None].repeat(1, hidden_size), remainds[:, None] * hidden[:, t, :], frame + ) + + fires = torch.stack(list_fires, 1) + frames = torch.stack(list_frames, 1) + list_ls = [] + len_labels = torch.round(alphas.sum(-1)).int() + max_label_len = len_labels.max() + for b in range(batch_size): + fire = fires[b, :] + l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze()) + pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device) + list_ls.append(torch.cat([l, pad_l], 0)) + return torch.stack(list_ls, 0), fires + + +def cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=False): + batch_size, len_time = alphas.size() + 
device = alphas.device + dtype = alphas.dtype + + threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device) + + fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device) + + if torch.cuda.get_device_name() == "Iluvatar MR-V100": + prefix_sum = torch.cumsum(alphas, dim=1) + else: + prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to( + torch.float32 + ) # cumsum precision degradation cause wrong result in extreme + prefix_sum_floor = torch.floor(prefix_sum) + dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1) + dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum) + + dislocation_prefix_sum_floor[:, 0] = 0 + dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor + + fire_idxs = dislocation_diff > 0 + fires[fire_idxs] = 1 + fires = fires + prefix_sum - prefix_sum_floor + if return_fire_idxs: + return fires, fire_idxs + return fires + + +def cif_v1(hidden, alphas, threshold): + fires, fire_idxs = cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=True) + + device = hidden.device + dtype = hidden.dtype + batch_size, len_time, hidden_size = hidden.size() + # frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device) + # prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1) + frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device) + prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).repeat((1, 1, hidden_size)) * hidden, dim=1) + + frames = prefix_sum_hidden[fire_idxs] + shift_frames = torch.roll(frames, 1, dims=0) + + batch_len = fire_idxs.sum(1) + batch_idxs = torch.cumsum(batch_len, dim=0) + shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0) + shift_batch_idxs[0] = 0 + shift_frames[shift_batch_idxs] = 0 + + remains = fires - torch.floor(fires) + # remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs] + remain_frames = 
remains[fire_idxs].unsqueeze(-1).repeat((1, hidden_size)) * hidden[fire_idxs] + + shift_remain_frames = torch.roll(remain_frames, 1, dims=0) + shift_remain_frames[shift_batch_idxs] = 0 + + frames = frames - shift_frames + shift_remain_frames - remain_frames + + # max_label_len = batch_len.max() + max_label_len = ( + torch.round(alphas.sum(-1)).int().max() + ) # torch.round to calculate the max length + + # frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device) + frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device) + indices = torch.arange(max_label_len, device=device).expand(batch_size, -1) + frame_fires_idxs = indices < batch_len.unsqueeze(1) + frame_fires[frame_fires_idxs] = frames + return frame_fires, fires + + +def cif_wo_hidden(alphas, threshold): + batch_size, len_time = alphas.size() + + # loop varss + integrate = torch.zeros([batch_size], device=alphas.device) + # intermediate vars along time + list_fires = [] + + for t in range(len_time): + alpha = alphas[:, t] + + integrate += alpha + list_fires.append(integrate) + + fire_place = integrate >= threshold + integrate = torch.where( + fire_place, + integrate - torch.ones([batch_size], device=alphas.device) * threshold, + integrate, + ) + + fires = torch.stack(list_fires, 1) + return fires diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..0d2f9d2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +requests +wheel +websocket-client +pydantic<2.0.0 +numpy<2.0 +PYYaml +Levenshtein +ruamel.yaml +nltk==3.7 +pynini==2.1.6 +soundfile \ No newline at end of file diff --git a/start_funasr.sh b/start_funasr.sh new file mode 100755 index 0000000..53381e4 --- /dev/null +++ b/start_funasr.sh @@ -0,0 +1,3 @@ +unset CUDA_VISIBLE_DEVICES +unset NVIDIA_VISIBLE_DEVICES +python3 ./test_funasr.py \ No newline at end of file diff --git a/test_funasr.py b/test_funasr.py new file mode 100644 index 0000000..4228e47 
--- /dev/null +++ b/test_funasr.py @@ -0,0 +1,180 @@ +import os +import time +import torchaudio +import torch +from funasr import AutoModel +from funasr.utils.postprocess_utils import rich_transcription_postprocess +from utils.calculate import cal_per_cer +import json + +def split_audio(waveform, sample_rate, segment_seconds=20): + segment_samples = segment_seconds * sample_rate + segments = [] + for i in range(0, waveform.shape[1], segment_samples): + segment = waveform[:, i:i + segment_samples] + if segment.shape[1] > 0: + segments.append(segment) + return segments + +def determine_model_type(model_name): + if "sensevoice" in model_name.lower(): + return "sense_voice" + elif "whisper" in model_name.lower(): + return "whisper" + elif "paraformer" in model_name.lower(): + return "paraformer" + elif "conformer" in model_name.lower(): + return "conformer" + elif "uniasr" in model_name.lower(): + return "uni_asr" + else: + return "unknown" + +def test_funasr(model_dir, audio_file, answer_file, use_gpu): + model_name = os.path.basename(model_dir) + model_type = determine_model_type(model_name) + + if torch.cuda.get_device_name() == "Iluvatar BI-V100" and model_type == "whisper": + # 天垓100情况下的Whisper需要绕过不支持算子 + torch.backends.cuda.enable_flash_sdp(False) + torch.backends.cuda.enable_mem_efficient_sdp(False) + torch.backends.cuda.enable_math_sdp(True) + + # 不使用VAD, punct,spk模型,就测试原始ASR能力 + model = AutoModel( + model=model_dir, + # vad_model="fsmn-vad", + # vad_kwargs={"max_single_segment_time": 30000}, + vad_model=None, + device="cuda:0" if use_gpu else "cpu", + disable_update=True + ) + + waveform, sample_rate = torchaudio.load(audio_file) + # print(waveform.shape) + duration = waveform.shape[1] / sample_rate + segments = split_audio(waveform, sample_rate, segment_seconds=20) + + generated_text = "" + processing_time = 0 + + if model_type == "uni_asr": + # uni_asr比较特殊,设计就是处理长音频的(自带VAD),切分的话前20s如果几乎没有人讲话全是音乐直接会报错 + # 因为可能会被切掉所有音频导致实际编解码输入为0 + ts1 = time.time() + res = 
model.generate( + input=audio_file + ) + generated_text = res[0]["text"] + ts2 = time.time() + processing_time = ts2 - ts1 + else: + # 按照切分的音频依次输入 + for i, segment in enumerate(segments): + segment_path = f"temp_seg_{i}.wav" + torchaudio.save(segment_path, segment, sample_rate) + ts1 = time.time() + if model_type == "sense_voice": + res = model.generate( + input=segment_path, + cache={}, + language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech" + use_itn=True, + batch_size_s=60, + merge_vad=False, + # merge_length_s=15, + ) + text = rich_transcription_postprocess(res[0]["text"]) + elif model_type == "whisper": + DecodingOptions = { + "task": "transcribe", + "language": "zh", + "beam_size": None, + "fp16": False, + "without_timestamps": False, + "prompt": None, + } + res = model.generate( + DecodingOptions=DecodingOptions, + input=segment_path, + batch_size_s=0, + ) + text = res[0]["text"] + elif model_type == "paraformer": + res = model.generate( + input=segment_path, + batch_size_s=300 + ) + text = res[0]["text"] + # paraformer模型会一个字一个字输出,中间夹太多空格会影响1-cer的结果 + text = text.replace(" ", "") + elif model_type == "conformer": + res = model.generate( + input=segment_path, + batch_size_s=300 + ) + text = res[0]["text"] + # elif model_type == "uni_asr": + # if i == 0: + # os.remove(segment_path) + # continue + # res = model.generate( + # input=segment_path + # ) + # text = res[0]["text"] + else: + raise RuntimeError("unknown model type") + ts2 = time.time() + generated_text += text + processing_time += (ts2 - ts1) + os.remove(segment_path) + + rtf = processing_time / duration + print("Text:", generated_text, flush=True) + print(f"Audio duration:\t{duration:.3f} s", flush=True) + print(f"Elapsed:\t{processing_time:.3f} s", flush=True) + print(f"RTF = {processing_time:.3f}/{duration:.3f} = {rtf:.3f}", flush=True) + with open(answer_file, 'r', encoding='utf-8') as f: + groundtruth_text = f.read() + acc = cal_per_cer(generated_text, groundtruth_text, "zh") + 
print(f"1-cer = {acc}", flush=True) + + return processing_time, acc, generated_text + +if __name__ == "__main__": + test_result = { + "time_cuda": 0, + "acc_cuda": 0, + "text_cuda": "", + "success": False + } + + try: + model_dict = { + "sense_voice": "SenseVoiceSmall", + "paraformer": "speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch", + "conformer": "speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch", + "whisper": "Whisper-large-v3", + "uni_asr": "speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline" + } + LOCAL_TEST = os.getenv("LOCAL_TEST", "false").lower() == "true" + K8S_TEST = os.getenv("K8S_TEST", "false").lower() == "true" + workspace_path = "../" if LOCAL_TEST else "/tmp/workspace" + model_dir = os.path.join("/model", model_dict["sense_voice"]) if LOCAL_TEST else os.environ["MODEL_DIR"] + audio_file = "lei-jun-test.wav" if LOCAL_TEST else os.path.join(workspace_path, os.environ["TEST_FILE"]) + answer_file = "lei-jun.txt" if LOCAL_TEST else os.path.join(workspace_path, os.environ["ANSWER_FILE"]) + result_file = "result.json" if LOCAL_TEST else os.path.join(workspace_path, os.environ["RESULT_FILE"]) + # test_funasr(model_dir, audio_file, answer_file, False) + processing_time, acc, generated_text = test_funasr(model_dir, audio_file, answer_file, True) + test_result["time_cuda"] = processing_time + test_result["acc_cuda"] = acc + test_result["text_cuda"] = generated_text + test_result["success"] = True + except Exception as e: + print(f"ASR测试出错: {e}", flush=True) + with open(result_file, "w", encoding="utf-8") as fp: + json.dump(test_result, fp, ensure_ascii=False, indent=4) + # 如果是SUT起来镜像的话,需要加上下面让pod永不停止以迎合k8s deployment, 本地测试以及docker run均不需要 + if K8S_TEST: + print(f"Start to sleep indefinitely", flush=True) + time.sleep(100000) \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/calculate.py 
def calculate_punctuation_ratio(datas: "List[Tuple[AudioItem, List[SegmentModel]]]") -> float:
    """Ratio of punctuation produced by the model vs. the references.

    Concatenates the reference text of every audio item and the generated
    text of every ASR segment, then compares total punctuation counts.

    :param datas: list of (reference item, ASR segments) pairs
    :return: generated / reference punctuation count; 0.0 when the
        references contain no punctuation at all (avoids ZeroDivisionError)
    """
    total_standard_punctuation = 0
    total_gen_punctuation = 0
    for answer, results in datas:
        standard_text = "".join(item.answer for item in answer.voice)
        gen_text = "".join(item.text for item in results)
        total_standard_punctuation += count_punctuation(standard_text)
        total_gen_punctuation += count_punctuation(gen_text)
    if total_standard_punctuation == 0:
        # BUG FIX: punctuation-free references previously crashed with a
        # ZeroDivisionError.
        return 0.0
    return total_gen_punctuation / total_standard_punctuation


def calculate_acc(datas: "List[Tuple[AudioItem, List[SegmentModel]]]", language: str) -> float:
    """Average 1-cer accuracy over all audio items.

    For each item the reference texts and the generated segment texts are
    concatenated and scored with cal_per_cer.

    :param datas: list of (reference item, ASR segments) pairs
    :param language: language code forwarded to cal_per_cer
    :return: mean accuracy; 0.0 for empty input (avoids ZeroDivisionError)
    """
    if not datas:
        return 0.0
    total_acc = 0
    for answer, results in datas:
        standard_text = "".join(item.answer for item in answer.voice)
        gen_text = "".join(item.text for item in results)
        total_acc += cal_per_cer(gen_text, standard_text, language)
    return total_acc / len(datas)


def get_alignment_type(language: str) -> str:
    """Pick the alignment unit for a language.

    :param language: ISO-ish language code
    :return: "chart" for character-aligned scripts, "word" otherwise
    """
    # Chinese, Japanese, Korean, Thai, Lao, Burmese, Khmer, Tibetan
    chart_langs = {"zh", "ja", "ko", "th", "lo", "my", "km", "bo"}
    return "chart" if language in chart_langs else "word"
def cal_per_cer(text: str, answer: str, language: str):
    """Compute 1 - CER of ``text`` against the reference ``answer``.

    Punctuation is stripped from both sides first. Character-aligned
    languages (see get_alignment_type) are scored per character with
    Levenshtein edit operations; other languages are scored on normalized
    tokens via difflib's SequenceMatcher.

    :param text: hypothesis text
    :param answer: reference text
    :param language: language code used to pick the alignment unit
    :return: 1 - cer; may be negative when the error rate exceeds 1
    """
    if not answer:
        # NOTE(review): mirrors the original semantics — with an empty
        # reference, a non-empty hypothesis scores 1.0 and an empty one 0.0.
        return 1.0 if text else 0.0

    text = remove_punctuation(text)
    answer = remove_punctuation(answer)

    alignment_type = get_alignment_type(language)
    if alignment_type == "chart":
        # FIX: the chart path never used the tokenizer output, yet the
        # original still ran the (expensive) normalization — skip it here.
        text_chars = list(text)
        answer_chars = list(answer)
        if not answer_chars:
            return 0.0
        ops = Levenshtein.editops(text_chars, answer_chars)
        insert = sum(1 for op in ops if op[0] == "insert")
        delete = sum(1 for op in ops if op[0] == "delete")
        replace = sum(1 for op in ops if op[0] == "replace")
    else:
        text_chars = Tokenizer.norm_and_tokenize([text], language)[0]
        answer_chars = Tokenizer.norm_and_tokenize([answer], language)[0]
        if not answer_chars:
            return 0.0
        matcher = SequenceMatcher(None, text_chars, answer_chars)
        insert = delete = replace = 0
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag == 'replace':
                # A replace of unequal spans costs the longer side.
                replace += max(i2 - i1, j2 - j1)
            elif tag == 'delete':
                delete += (i2 - i1)
            elif tag == 'insert':
                insert += (j2 - j1)

    cer = (insert + delete + replace) / len(answer_chars)
    return 1 - cer
+ """ + total_insert = 0 + total_delete = 0 + total_replace = 0 + total_ref_len = 0 + + for text, answer in samples: + + if not answer: + return 1.0 if text else 0.0 # 如果标签为空,预测也为空则为 0,否则为 1 + + text = remove_punctuation(text) + answer = remove_punctuation(answer) + + text_chars = list(text) + answer_chars = list(answer) + + ops = Levenshtein.editops(text_chars, answer_chars) + insert = len(list(filter(lambda x: x[0] == "insert", ops))) + delete = len(list(filter(lambda x: x[0] == "delete", ops))) + replace = len(list(filter(lambda x: x[0] == "replace", ops))) + + total_insert += insert + total_delete += delete + total_replace += replace + total_ref_len += len(answer_chars) + + total_cer = (total_insert + total_delete + total_replace) / total_ref_len if total_ref_len > 0 else 0.0 + total_acc = 1 - total_cer + return total_acc + + +def remove_punctuation(text: str) -> str: + # 去除中英文标点 + return re.sub(r'[^\w\s\u4e00-\u9fff]', '', text) + + +def count_punctuation(text: str) -> int: + """统计文本中的指定标点个数""" + return len(re.findall(r"[^\w\s\u4e00-\u9fa5]", text)) + + +from typing import List, Optional, Tuple + + +def calculate_standard_sentence_delay(datas: List[Tuple[AudioItem, List[SegmentModel]]]) -> float: + for audio_item, asr_results in datas: + if not audio_item.voice: + continue # 没有标答内容 + # + audio_texts = [] + asr_texts = [] + + ref = audio_item.voice[0] # 默认取第一个标答段 + ref_end_ms = int(ref.end * 1000) + + # 找出所有ASR中包含标答尾字的文本(简化为包含标答最后一个字) + target_char = ref.answer.strip()[-1] # 标答尾字 + matching_results = [r for r in asr_results if target_char in r.text and r.words] + + if not matching_results: + continue # 没有找到包含尾字的ASR段 + + # 找出这些ASR段中最后一个词的end_time,最大值作为尾字时间 + latest_word_time = max(word.end_time for r in matching_results for word in r.words) + + delay = latest_word_time - ref_end_ms + + print(audio_item) + print(asr_results) + return 0 + + +def align_texts(ref_text: str, hyp_text: str) -> List[Tuple[Optional[int], Optional[int]]]: + """ + 使用编辑距离计算两个字符串的字符对齐 + 
def align_tokens(ref_text: List[str], hyp_text: List[str]) -> List[Tuple[Optional[int], Optional[int]]]:
    """Align two token sequences index-by-index.

    Uses difflib opcodes: equal/replace spans pair indices positionally (a
    replace of unequal lengths only pairs the overlapping prefix), deletes
    yield (ref_idx, None) and inserts yield (None, hyp_idx).

    :param ref_text: reference token list
    :param hyp_text: hypothesis token list
    :return: list of (ref_idx, hyp_idx) pairs; one side may be None
    """
    pairs: List[Tuple[Optional[int], Optional[int]]] = []
    for tag, r0, r1, h0, h1 in SequenceMatcher(None, ref_text, hyp_text).get_opcodes():
        if tag in ('equal', 'replace'):
            pairs.extend(zip(range(r0, r1), range(h0, h1)))
        elif tag == 'delete':
            pairs.extend((idx, None) for idx in range(r0, r1))
        elif tag == 'insert':
            pairs.extend((None, idx) for idx in range(h0, h1))
    return pairs
def find_head_word_time(
    ref_text: List[str],
    pred_text: List[str],
    merged_words: List[WordModel],
    char2word_idx: List[int],
    alignment: List[Tuple[Optional[int], Optional[int]]],
) -> Optional[WordModel]:
    """Locate the ASR word corresponding to the first reference character.

    Scans the alignment for a pair whose reference index is 0, maps the
    aligned hypothesis character index to its word via char2word_idx and
    returns that WordModel.

    :param ref_text: reference character list (index 0 is the head)
    :param pred_text: hypothesis character list (unused; kept for symmetry
        with find_tail_word_time)
    :param merged_words: WordModel list for the merged hypothesis
    :param char2word_idx: hypothesis char index -> word index mapping
    :param alignment: (ref_idx, hyp_idx) pairs from the aligner
    :return: the matching WordModel, or None when no in-range match exists
    """
    for ref_pos, hyp_pos in alignment:
        head_match = ref_pos == 0 and hyp_pos is not None
        # Out-of-range hits are skipped rather than aborting the scan.
        if head_match and 0 <= hyp_pos < len(char2word_idx):
            return merged_words[char2word_idx[hyp_pos]]
    return None
def rebuild_char2word_idx(pred_tokens: List[str], merged_words: List[WordModel]) -> List[int]:
    """Rebuild the character->word mapping so it matches ``pred_tokens``.

    Emits one word index per character of each word's text, truncated so
    the mapping never grows longer than ``pred_tokens``.

    :param pred_tokens: flattened hypothesis tokens (one per character)
    :param merged_words: WordModel list whose texts supply the characters
    :return: list mapping each hypothesis char position to a word index
    """
    limit = len(pred_tokens)
    mapping: List[int] = []
    for word_pos, word in enumerate(merged_words):
        for _ in word.text:
            if len(mapping) >= limit:
                return mapping  # pred_tokens exhausted — stop early
            mapping.append(word_pos)
    return mapping
import re

def normalize(text: str) -> str:
    """Lower-case and drop every character that is not a word character or
    an apostrophe."""
    return re.sub(r"[^\w']+", '', text.lower())

def build_hyp_token_to_asr_word_index(hyp_tokens: List[str], asr_words: List[WordModel]) -> List[int]:
    """Greedy forward mapping from hypothesis tokens to ASR word indices.

    Tokens and words are compared after normalize(); a token matches a
    word when they are equal or one contains the other, which tolerates
    small tokenization differences. Unmatched tokens keep -1.

    :param hyp_tokens: hypothesis token list
    :param asr_words: WordModel list from the ASR output
    :return: list parallel to hyp_tokens with the matched word index or -1
    """
    mapping = [-1] * len(hyp_tokens)
    tok_pos = 0
    word_pos = 0

    while tok_pos < len(hyp_tokens) and word_pos < len(asr_words):
        tok = normalize(hyp_tokens[tok_pos])
        word = normalize(asr_words[word_pos].text)

        if tok == word or tok in word or word in tok:
            mapping[tok_pos] = word_pos
            tok_pos += 1
            word_pos += 1
        else:
            # Skip the unmatched hypothesis token; the word stays available
            # for the next token.
            tok_pos += 1

    return mapping
映射到 ASR word index + asr_word_idx = hyp_to_asr_word_idx[tail_hyp_idx] + if asr_word_idx is None or asr_word_idx < 0 or asr_word_idx >= len(asr_words): + return None + + return asr_words[asr_word_idx] + + +def find_tail_word2( + ref_tokens: List[str], # 标答token列表 + hyp_tokens: List[str], # 预测token列表 + alignment: List[Tuple[Optional[int], Optional[int]]], # 对齐 (ref_idx, hyp_idx) + hyp_to_asr_word_idx: List[int], # hyp token 对应的 ASR word 索引 + asr_words: List[WordModel], + punct_set: set = set(",。!?、,.!?;;"), + enable_debug: bool = False +) -> Optional[WordModel]: + """ + 找到 ASR 结果中对应预测文本“最后一个有效对齐词”的 WordModel(tail word) + + 返回 None 表示没找到 + """ + # Step 1. 找到 ref_tokens 中最后一个非标点的索引 + tail_ref_idx = len(ref_tokens) - 1 + while tail_ref_idx >= 0 and ref_tokens[tail_ref_idx] in punct_set: + tail_ref_idx -= 1 + if tail_ref_idx < 0: + if enable_debug: + print("全是标点,尾字找不到") + return None + + # Step 2. alignment 中查找 tail_ref_idx 对应的 hyp_idx + tail_hyp_idx = None + for ref_idx, hyp_idx in reversed(alignment): + if ref_idx == tail_ref_idx and hyp_idx is not None: + tail_hyp_idx = hyp_idx + break + + # Step 3. 
fallback:如果找不到,向前找最近一个非标点且能对齐的 ref_idx + fallback_idx = tail_ref_idx + while tail_hyp_idx is None and fallback_idx >= 0: + if ref_tokens[fallback_idx] not in punct_set: + for ref_idx, hyp_idx in reversed(alignment): + if ref_idx == fallback_idx and hyp_idx is not None: + tail_hyp_idx = hyp_idx + break + fallback_idx -= 1 + + if tail_hyp_idx is None or tail_hyp_idx >= len(hyp_to_asr_word_idx): + if enable_debug: + print(f"tail_hyp_idx 无法找到或超出范围: {tail_hyp_idx}") + return None + + asr_word_idx = hyp_to_asr_word_idx[tail_hyp_idx] + if asr_word_idx is None or asr_word_idx < 0 or asr_word_idx >= len(asr_words): + if enable_debug: + print(f"asr_word_idx 无效: {asr_word_idx}") + return None + + return asr_words[asr_word_idx] + + +def find_head_word( + ref_tokens: List[str], + hyp_tokens: List[str], + alignment: List[Tuple[Optional[int], Optional[int]]], + hyp_to_asr_word_idx: dict, + asr_words: List[WordModel], + punct_set: set = set(",。!?、,.!?;;") +) -> Optional[WordModel]: + """ + 通过参考文本开头第一个非标点token,定位对应预测token,再映射到ASR词,拿时间 + """ + + # 1. 找到参考文本开头非标点索引 + head_ref_idx = 0 + while head_ref_idx < len(ref_tokens) and ref_tokens[head_ref_idx] in punct_set: + head_ref_idx += 1 + if head_ref_idx >= len(ref_tokens): + return None + + # 2. 找到 alignment 中对应的 hyp_idx + head_hyp_idx = None + for ref_idx, hyp_idx in alignment: + if ref_idx == head_ref_idx and hyp_idx is not None: + head_hyp_idx = hyp_idx + break + + if head_hyp_idx is None or head_hyp_idx >= len(hyp_to_asr_word_idx): + return None + + # 3. 
映射到 asr_words 的索引 + asr_word_idx = hyp_to_asr_word_idx[head_hyp_idx] + if asr_word_idx is None or asr_word_idx < 0 or asr_word_idx >= len(asr_words): + return None + + return asr_words[asr_word_idx] + + +def calculate_sentence_delay( + datas: List[Tuple[AudioItem, List[SegmentModel]]], language: str = "zh" +) -> (float, float, float): + """ + + :param datas: 标答和模型结果 + :return: 未找到尾字的比例,修正0的数量,平均延迟时间。 + """ + tail_offset_time = 0 # 尾字偏移时间 + standard_offset_time = 0 # 尾字偏移时间 + tail_not_found = 0 # 未找到尾字数量 + tail_found = 0 # 找到尾字数量 + + standard_fix = 0 + final_fix = 0 + + head_offset_time = 0 # 尾字偏移时间 + final_offset_time = 0 # 尾字偏移时间 + head_not_found = 0 # 未找到尾字数量 + head_found = 0 # 找到尾字数量 + + for audio_item, asr_list in datas: + if not audio_item.voice: + continue + # (以防万一)将标答中所有的文本连起来,并将标答中最后一条信息的结束时间作为结束时间。 + ref_text = "" + for voice in audio_item.voice: + ref_text = ref_text + voice.answer.strip() + if not ref_text: + continue + logger.debug(f"-=-=-=-=-=-=-=-=-=-=-=-=-=start-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=--=") + + ref_end_ms = int(audio_item.voice[-1].end * 1000) + ref_start_ms = int(audio_item.voice[0].start * 1000) + + # 录音所有的片段都过一下,text的end和receive的end + # 统计定稿时间。 接收到segment的时间 - segment.end_time + # + # pred_text = "" + # asr_words: List[WordModel] = [] + # for asr in asr_list: + # pred_text = pred_text + asr.text + # final_offset = asr.receive_time - asr.end_time + # asr_words = asr_words + asr.words + # for word in asr.words: + # word.segment = asr + # if final_offset < 0: + # final_fix = final_fix + 1 + # # 统计被修复的数量 + # final_offset = 0 + # # 统计定稿偏移时间 + # final_offset_time = final_offset_time + final_offset + + pred_text = [] + asr_words: List[WordModel] = [] + temp_final_offset_time = 0 + for asr in asr_list: + pred_text = pred_text + [word.text for word in asr.words] + final_offset = asr.receive_time - asr.end_time + logger.debug(f"asr.receive_time {asr.receive_time} , asr.end_time {asr.end_time} , final_offset {final_offset}") + asr_words = asr_words + 
asr.words + for word in asr.words: + word.segment = asr + if final_offset < 0: + final_fix = final_fix + 1 + # 统计被修复的数量 + final_offset = 0 + # 统计定稿偏移时间 + temp_final_offset_time = temp_final_offset_time + final_offset + final_offset_time = final_offset_time + temp_final_offset_time / len(asr_list) + + # 处理模型给出的结果。 + logger.debug(f"text: {ref_text},pred_text: {pred_text}") + # 计算对应关系 + + # pred_tokens 是与原文一致的,只是可能多了几个为空的位置。需要打平为一维数组,并记录对应的word的位置。 + flat_pred_tokens = [] + hyp_to_asr_word_idx = {} # key: flat_pred_token_index -> asr_word_index + + alignment_type = get_alignment_type(language) + if alignment_type == "chart": + label_tokens = Tokenizer.tokenize([ref_text], language)[0] + pred_tokens = Tokenizer.tokenize(pred_text, language) + for asr_idx, token_group in enumerate(pred_tokens): + for token in token_group: + flat_pred_tokens.append(token) + hyp_to_asr_word_idx[len(flat_pred_tokens) - 1] = asr_idx + alignment = align_texts(label_tokens, "".join(flat_pred_tokens)) + else: + label_tokens = Tokenizer.norm_and_tokenize([ref_text], language)[0] + pred_tokens = Tokenizer.norm_and_tokenize(pred_text, language) + for asr_idx, token_group in enumerate(pred_tokens): + for token in token_group: + flat_pred_tokens.append(token) + hyp_to_asr_word_idx[len(flat_pred_tokens) - 1] = asr_idx + alignment = align_tokens(label_tokens, flat_pred_tokens) + + logger.debug(f"ref_tokens: {label_tokens}") + logger.debug(f"pred_tokens: {pred_tokens}") + logger.debug(f"alignment sample: {alignment[:30]}") # 只打印前30个,避免日志过大 + logger.debug(f"hyp_to_asr_word_idx: {hyp_to_asr_word_idx}") + + head_word_info = find_head_word(label_tokens, pred_tokens, alignment, hyp_to_asr_word_idx, asr_words) + + if head_word_info is None: + # 统计没有找到首字的数量 + head_not_found = head_not_found + 1 + logger.debug(f"未找到首字") + else: + logger.debug(f"head_word: {head_word_info.text} ref_start_ms:{ref_start_ms}") + # 找到首字 + # 统计首字偏移时间 首字在策略中出现的word的时间 - 标答start_time + head_offset_time = head_offset_time + 
abs(head_word_info.start_time - ref_start_ms) + # 统计找到首字的数量 + head_found += 1 + + # 找尾字所在的模型返回words信息。 + + tail_word_info = find_tail_word(label_tokens, pred_tokens, alignment, hyp_to_asr_word_idx, asr_words) + if tail_word_info is None: + # 没有找到尾字,记录数量 + tail_not_found = tail_not_found + 1 + logger.debug(f"未找到尾字") + else: + # 找到尾字了 + logger.debug(f"tail_word: {tail_word_info.text} ref_end_ms: {ref_end_ms}") + # 统计尾字偏移时间 标答的end_time - 策略尾字所在word的end_time + tail_offset_time = abs(ref_end_ms - tail_word_info.end_time) + tail_offset_time + + # 统计标答句延迟时间 策略尾字所在word的实际接收时间 - 标答句end时间 + standard_offset = tail_word_info.segment.receive_time - ref_end_ms + logger.debug(f"tail_word_info.segment.receive_time {tail_word_info.segment.receive_time } , tail_word_info.end_time {tail_word_info.end_time} , ref_end_ms {ref_end_ms}") + # 如果小于0修正为0 + if standard_offset < 0: + standard_offset = 0 + # 统计被修正的数量 + standard_fix = standard_fix + 1 + standard_offset_time = standard_offset + standard_offset_time + + # 统计找到尾字的数量 + tail_found += 1 + + logger.info(f"-=-=-=-=-=-=-=-=-=-=-=-=-=end-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=--=") + + logger.debug( + f"找到首字字数量: {head_found},未找到首字数量:{head_not_found},找到尾字数量: {tail_found},未找到尾字数量:{tail_not_found},修正标答偏移负数数量:{standard_fix},修正定稿偏移负数数量:{final_fix},") + logger.debug( + f"尾字偏移总时间:{tail_offset_time},标答句偏移总时间:{standard_offset_time},首字偏移总时间:{head_offset_time},定稿偏移总时间:{final_offset_time},") + + # + # 统计平均值 + head_not_found_ratio = head_not_found / (head_found + head_not_found) + tail_not_found_ratio = tail_not_found / (tail_found + tail_not_found) + + average_tail_offset = tail_offset_time / tail_found / 1000 + average_head_offset = head_offset_time / head_found / 1000 + + average_standard_offset = standard_offset_time / tail_found / 1000 + average_final_offset = final_offset_time / tail_found / 1000 + + logger.info( + 
f"首字未找到比例:{head_not_found_ratio},尾字未找到比例:{tail_not_found_ratio},首字偏移时间:{average_head_offset},尾字偏移时间:{average_tail_offset},标答句偏移时间:{average_standard_offset},定稿偏移时间:{average_final_offset}") + + return head_not_found_ratio, average_head_offset, tail_not_found_ratio, average_standard_offset, average_final_offset, average_tail_offset + + +if __name__ == '__main__': + checks = [ + { + "type": "zh", + "ref": "今天天气真好", + "hyp": "今天真好" + }, + { + "type": "zh", + "ref": "我喜欢吃苹果", + "hyp": "我很喜欢吃香蕉" + }, + { + "type": "zh", + "ref": "我喜欢吃苹果", + "hyp": "我喜欢吃苹果" + }, + { + "type": "en", + "ref": "I like to eat apples", + "hyp": "I really like eating apples" + }, + { + "type": "en", + "ref": "She is going to the market", + "hyp": "She went market" + }, + { + "type": "en", + "ref": "Hello world", + "hyp": "Hello world" + }, + { + "type": "en", + "ref": "Good morning", + "hyp": "Bad night" + }, + ] + for check in checks: + ref = check.get("ref") + type = check.get("type") + hyp = check.get("hyp") + + res1 = align_texts(ref, hyp) + + res2 = align_tokens(list(ref), list(hyp)) + + from utils.tokenizer import Tokenizer + + start = time.time() + tokens_pred = Tokenizer.norm_and_tokenize([ref], type) + print(time.time() - start) + + start = time.time() + Tokenizer.norm_and_tokenize([ref + ref + ref], type) + print(time.time() - start) + + tokens_label = Tokenizer.norm_and_tokenize([hyp], type) + print(tokens_pred) + print(tokens_label) + res3 = align_tokens(tokens_pred[0], tokens_label[0]) + print(res1 == res2) + print(res1 == res3) diff --git a/utils/client.py b/utils/client.py new file mode 100644 index 0000000..17fb56c --- /dev/null +++ b/utils/client.py @@ -0,0 +1,195 @@ +import json, os, threading, time, traceback +from typing import ( + Any, List +) +from copy import deepcopy +from utils.logger import logger + +from websocket import ( + create_connection, + WebSocketConnectionClosedException, + ABNF +) + +from utils.model import ( + ASRResponseModel, + SegmentModel +) +import 
class ASRWebSocketClient:
    """Streaming ASR websocket client: sends audio chunks to the service
    and collects the recognition segments that come back."""

    def __init__(self, url: str):
        self.endpoint = f"{url}/recognition"
        # self.ctx = deepcopy(ctx)
        self.ws = None
        self.conn_attempts = -1
        self.failed = False
        self.terminate_time = float("inf")
        self.sent_timestamps: List[float] = []
        self.received_timestamps: List[float] = []
        self.results: List[Any] = []
        self.connected = True
        logger.info(f"Target endpoint: {self.endpoint}")

    def execute(self, path: str) -> "List[SegmentModel]":
        """Send the audio file while concurrently receiving results.

        Receive timestamps matter for the delay metrics, so sending and
        receiving run on separate threads measured from the same start time.

        :param path: path to the audio file to stream
        :return: the segments collected by receive(), or None on error
        """
        send_thread = threading.Thread(target=self.send, args=(path,))
        start_time = time.time()
        send_thread.start()

        # Queue used to hand the receive() return value back to this thread.
        result_queue = queue.Queue()

        def receive_thread_fn():
            try:
                result_queue.put(self.receive(start_time))
            except Exception:
                # NOTE(review): the exception is swallowed and None is
                # returned to the caller — consider propagating it instead.
                result_queue.put(None)

        receive_thread = threading.Thread(target=receive_thread_fn)
        receive_thread.start()

        send_thread.join()
        receive_thread.join()

        return result_queue.get()

    def initialize_connection(self, language: str):
        """Open the websocket and perform the init handshake.

        Retries up to 5 times within the configured deadline; raises when
        the connection cannot be established.
        """
        expiration = time.time() + float(os.getenv("end_time", "2"))
        self.connected = False
        init = False
        while True:
            try:
                if not init:
                    logger.debug(f"建立ws链接:发送建立ws链接请求。")
                    self.ws = create_connection(self.endpoint)
                    body = json.dumps(self._get_init_payload(language))
                    logger.debug(f"建立ws链接:发送初始化数据。{body}")
                    self.ws.send(body)
                    init = True
                msg = self.ws.recv()
                logger.debug(f"收到响应数据: {msg}")
                if len(msg) == 0:
                    time.sleep(0.5)  # wait for the server to write back
                    continue
                if isinstance(msg, str):
                    try:
                        msg = json.loads(msg)
                    except Exception:
                        raise Exception("建立ws链接:响应数据非json格式!")
                if isinstance(msg, dict):
                    connected = msg.get("success")
                    if connected:
                        logger.debug("建立ws链接:链接建立成功!")
                        self.conn_attempts = 0
                        self.connected = True
                        return
                    else:
                        logger.info("建立ws链接:链接建立失败!")
                        init = False
                        self.conn_attempts = self.conn_attempts + 1
                        if self.conn_attempts > 5:
                            raise ConnectionRefusedError("重试5次后,仍然无法建立ws链接。")
                        if time.time() > expiration:
                            raise RuntimeError("建立ws链接:链接建立超时!")

            # BUG FIX: the original `except A or B:` evaluated `A or B` to
            # the first class, so TimeoutError was never caught here.
            except (WebSocketConnectionClosedException, TimeoutError):
                raise Exception("建立ws链接:初始化阶段连接中断,退出。")
            except Exception:
                logger.info("建立ws链接:链接建立失败!")
                init = False
                self.conn_attempts = self.conn_attempts + 1
                if self.conn_attempts > 5:
                    raise ConnectionRefusedError("重试5次后,仍然无法建立ws链接。")
                if time.time() > expiration:
                    raise RuntimeError("建立ws链接:链接建立超时!")

    def shutdown(self):
        """Close the websocket, ignoring errors, and mark as disconnected."""
        try:
            if self.ws:
                self.ws.close()
        except Exception:
            pass
        self.connected = False

    def _get_init_payload(self, language="zh") -> dict:
        # Handshake body sent right after the connection opens.
        return {
            "language": language
        }

    def _get_finish_payload(self) -> dict:
        # End-of-stream marker telling the server no more audio follows.
        return {
            "end": "true"
        }

    def send(self, path):
        """Stream the audio file in 3200-byte chunks, then send the finish
        marker. Chunks are paced at one per 100 ms to mimic live audio."""
        skip_wav_header = path.endswith("wav")
        with open(path, "rb") as f:
            if skip_wav_header:
                # assumes a canonical 44-byte WAV header — TODO confirm for
                # files with extra RIFF chunks
                f.read(44)

            while chunk := f.read(3200):
                logger.debug(f"发送 {len(chunk)} 字节数据.")
                self.ws.send(chunk, opcode=ABNF.OPCODE_BINARY)
                time.sleep(0.1)
            self.ws.send(json.dumps(self._get_finish_payload()))
def apply_env_to_values(values, envs):
    """Merge ``envs`` into the helm values' ``env`` list, in place.

    Entries whose name already exists get their value overwritten;
    unknown names are appended as new ``{"name", "value"}`` items.

    :param values: helm values dict (mutated and returned)
    :param envs: mapping of env-var name -> value
    :return: the same ``values`` dict
    """
    env_entries = values.setdefault("env", [])
    # Index existing entries by name (first occurrence wins, matching
    # list.index semantics).
    name_to_pos = {}
    for pos, entry in enumerate(env_entries):
        name_to_pos.setdefault(entry["name"], pos)

    for name, value in envs.items():
        if name in name_to_pos:
            env_entries[name_to_pos[name]]["value"] = value
        else:
            env_entries.append({"name": name, "value": value})
    return values
+ ) + elif isinstance(base_value, list) and isinstance(incr_value, list): + base_value.extend(incr_value) + else: + base_value = incr_value + return base_value + + +def gen_chart_tarball(docker_image, policy=None): + """docker image加上digest并根据image生成helm chart包, 失败直接异常退出 + + Args: + docker_image (_type_): docker image + + Returns: + tuple[BytesIO, dict]: [helm chart包file对象, values内容] + """ + # load values template + with open(os.path.join(sut_chart_root, "values.yaml.tmpl")) as fp: + yaml = YAML(typ="rt") + values = yaml.load(fp) + # update docker_image + get_image_hash_url = os.getenv("GET_IMAGE_HASH_URL", None) + if get_image_hash_url is not None: + # convert tag to hash for docker_image + resp = requests.get( + get_image_hash_url, + headers=lb_headers, + params={"image": docker_image}, + timeout=600, + ) + assert resp.status_code == 200, ( + "Convert tag to hash for docker image failed, API retcode %d" + % resp.status_code + ) + resp = resp.json() + assert resp[ + "success" + ], "Convert tag to hash for docker image failed, response: %s" % str(resp) + token = resp["data"]["image"].rsplit(":", 2) + assert len(token) == 3, "Invalid docker image %s" % resp["data"]["image"] + values["image"]["repository"] = token[0] + values["image"]["tag"] = ":".join(token[1:]) + else: + token = docker_image.rsplit(":", 1) + if len(token) != 2: + raise RuntimeError("Invalid docker image %s" % docker_image) + values["image"]["repository"] = token[0] + values["image"]["tag"] = token[1] + values["image"]["pullPolicy"] = policy + # output values.yaml + with open(os.path.join(sut_chart_root, "values.yaml"), "w") as fp: + yaml = YAML(typ="rt") + yaml.dump(values, fp) + # tarball + tarfp = io.BytesIO() + with tarfile.open(fileobj=tarfp, mode="w:gz") as tar: + tar.add( + sut_chart_root, arcname=os.path.basename(sut_chart_root), recursive=True + ) + tarfp.seek(0) + logger.debug(f"Generated chart using values: {values}") + return tarfp, values + + +def deploy_chart( + name_suffix, + 
readiness_timeout, + chart_str=None, + chart_fileobj=None, + extra_values=None, + restart_count_limit=3, + pullimage_count_limit=3, +): + """部署sut, 失败直接异常退出 + + Args: + name_suffix (str): 同一个job有多个sut时, 区分不同sut的名称 + readiness_timeout (int): readiness超时时间, 单位s + chart_str (int, optional): chart url, 不为None则忽略chart_fileobj. Defaults to None. + chart_fileobj (BytesIO, optional): helm chart包file对象, chart_str不为None使用. Defaults to None. + extra_values (dict, optional): helm values的补充内容. Defaults to None. + restart_count_limit (int, optional): sut重启次数限制, 超出则异常退出. Defaults to 3. + pullimage_count_limit (int, optional): image拉取次数限制, 超出则异常退出. Defaults to 3. + + Returns: + tuple[str, str]: [用于访问服务的k8s域名, 用于unload_sut的名称] + """ + logger.info( + f"Deploying SUT application for JOB {JOB_ID}, name_suffix {name_suffix}, extra_values {extra_values}" + ) + # deploy + payload = { + "job_id": JOB_ID, + "resource_name": name_suffix, + "priorityclassname": os.environ.get("priorityclassname"), + } + extra_values = {} if not extra_values else extra_values + payload["values"] = json.dumps(extra_values, ensure_ascii=False) + if chart_str is not None: + payload["helm_chart"] = chart_str + resp = requests.post( + LOAD_SUT_URL, + data=payload, + headers=lb_headers, + timeout=600, + ) + else: + assert ( + chart_fileobj is not None + ), "Either chart_str or chart_fileobj should be set" + resp = requests.post( + LOAD_SUT_URL, + data=payload, + headers=lb_headers, + files=[("helm_chart_file", (name_suffix + ".tgz", chart_fileobj))], + timeout=600, + ) + if resp.status_code != 200: + raise RuntimeError( + "Failed to deploy application status_code %d %s" + % (resp.status_code, resp.text) + ) + resp = resp.json() + if not resp["success"]: + raise RuntimeError("Failed to deploy application response %r" % resp) + service_name = resp["data"]["service_name"] + sut_name = resp["data"]["sut_name"] + logger.info(f"SUT application deployed with service_name {service_name}") + # waiting for appliation ready + 
def check_sut_ready_from_resp(
    service_name,
    running_at,
    readiness_timeout,
    restart_count_limit,
    pullimage_count_limit,
):
    """Poll the SUT-info API once and report whether all SUT pods are ready.

    Args:
        service_name (str): k8s service name (used in log messages only).
        running_at (float | None): epoch seconds when every pod first reached
            Running, or None if that has not happened yet.
        readiness_timeout (int): max seconds pods may stay Running-but-not-Ready.
        restart_count_limit (int): abort when a container restarts more often.
        pullimage_count_limit (int): abort when image pulls fail more often.

    Returns:
        tuple[bool, float | None]: (ready, updated running_at). Transient
        failures (network error, bad status code, pending pods) return False
        so the caller keeps polling; fatal pod states raise RuntimeError.
    """
    try:
        resp = requests.get(
            f"{GET_JOB_SUT_INFO_URL}/{JOB_ID}",
            headers=lb_headers,
            params={"with_detail": True},
            timeout=600,
        )
    except Exception as e:
        # BUG FIX: `logger.warning(msg, e)` passed the exception as a lazy
        # %-format argument for a message without a placeholder, so it was
        # never rendered; format it into the message instead.
        logger.warning(
            f"Exception occured while getting SUT application {service_name} status: {e}"
        )
        return False, running_at
    if resp.status_code != 200:
        logger.warning(
            f"Get SUT application {service_name} status failed with status_code {resp.status_code}"
        )
        return False, running_at
    resp = resp.json()
    if not resp["success"]:
        logger.warning(
            f"Get SUT application {service_name} status failed with response {resp}"
        )
        return False, running_at
    if len(resp["data"]["sut"]) == 0:
        logger.warning("Empty SUT application status")
        return False, running_at
    # Log a copy stripped of the bulky per-pod `detail` payload.
    resp_data_sut = copy.deepcopy(resp["data"]["sut"])
    for status in resp_data_sut:
        del status["detail"]
    logger.info(f"Got SUT application status: {resp_data_sut}")
    for status in resp["data"]["sut"]:
        if status["phase"] in ["Succeeded", "Failed"]:
            raise RuntimeError(
                f"Some pods of SUT application {service_name} terminated with status {status}"
            )
        elif status["phase"] in ["Pending", "Unknown"]:
            return False, running_at
        elif status["phase"] != "Running":
            raise RuntimeError(
                f"Unexcepted pod status {status} of SUT application {service_name}"
            )
        # First time we observe Running: start the readiness clock.
        if running_at is None:
            running_at = time.time()
        for ct in status["detail"]["status"]["container_statuses"]:
            if ct["restart_count"] > 0:
                logger.info(
                    f"pod {status['pod_name']} restart count = {ct['restart_count']}"
                )
                if ct["restart_count"] > restart_count_limit:
                    raise RuntimeError(
                        f"pod {status['pod_name']} restart too many times(over {restart_count_limit})"
                    )
            # .get() is tolerant of a missing "waiting" key in the k8s state.
            waiting = ct["state"].get("waiting")
            if (
                waiting is not None
                and "reason" in waiting
                and waiting["reason"] in ["ImagePullBackOff", "ErrImagePull"]
            ):
                pod_name = status["pod_name"]
                # BUG FIX: the module-level `pull_num` is a bare defaultdict()
                # without a default_factory, so `pull_num[key] += 1` raised
                # KeyError on first use; count via .get() instead.
                pull_num[pod_name] = pull_num.get(pod_name, 0) + 1
                # BUG FIX: the original mixed f-string braces into a %-format
                # string, so the pull count was printed literally.
                logger.info(
                    "pod %s has %d times inspect pulling image info: %s"
                    % (pod_name, pull_num[pod_name], waiting)
                )
                if pull_num[pod_name] > pullimage_count_limit:
                    raise RuntimeError(f"pod {pod_name} cannot pull image")
        if not status["conditions"]["Ready"]:
            # Running but never Ready: give up after readiness_timeout seconds.
            if running_at is not None and time.time() - running_at > readiness_timeout:
                raise RuntimeError(
                    f"SUT Application readiness has exceeded readiness_timeout:{readiness_timeout}s"
                )
            return False, running_at
    return True, running_at
# -*- coding: utf-8 -*-
import logging
import os

# Default level for the "detailed_logger" below; flip to the commented
# DEBUG pair for verbose local runs.
level = logging.INFO
level_str = "INFO"

# level = logging.DEBUG
# level_str = "DEBUG"

# Root logger configuration; the LOGLEVEL env var overrides the default.
logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", level_str),
)
# NOTE(review): getLogger(__file__) keys the logger by file path; __name__ is
# the usual convention — confirm before changing (doc-only note).
logger = logging.getLogger(__file__)

# another logger

# Second, more detailed logger with its own handler and format
# (includes pathname:lineno for easier debugging).
log = logging.getLogger("detailed_logger")

# Do not bubble records up to the root logger — avoids double printing.
log.propagate = False



log.setLevel(level)

formatter = logging.Formatter(
    "[%(asctime)s] %(levelname)s : %(pathname)s:%(lineno)d - %(message)s",
    "%Y-%m-%d %H:%M:%S",
)

# Console handler for the detailed logger, same level as the logger itself.
streamHandler = logging.StreamHandler()
streamHandler.setLevel(level)
streamHandler.setFormatter(formatter)
log.addHandler(streamHandler)
pred_end_time_list.append(pred_item.recognition_results.end_time) + sentence_start = pred_item.recognition_results.final_result + # check start0 < end0 < start1 < end1 < start2 < end2 - ... + if IN_TEST: + print(pred_start_time_list) + print(pred_end_time_list) + pred_time_list = [] + i, j = 0, 0 + while i < len(pred_start_time_list) and j < len(pred_end_time_list): + pred_time_list.append(pred_start_time_list[i]) + pred_time_list.append(pred_end_time_list[j]) + i += 1 + j += 1 + if i < len(pred_start_time_list): + pred_time_list.append(pred_start_time_list[-1]) + for i in range(1, len(pred_time_list)): + # 这里给个 600ms 的宽限 + if pred_time_list[i] < pred_time_list[i - 1] - 0.6: + logger.error("识别的 start、end 不符合 start0 < end0 < start1 < end1 < start2 < end2 ...") + logger.error( + f"当前识别的每个句子开始和结束时间分别为: \ + 开始时间:{pred_start_time_list}, \ + 结束时间:{pred_end_time_list}" + ) + start_end_count += 1 + # change_product_available() + # 时间前后差值 300ms 范围内 + start_time_align_count = 0 + end_time_align_count = 0 + for label_start_time in label_start_time_list: + for pred_start_time in pred_start_time_list: + if pred_start_time <= label_start_time + 0.3 and pred_start_time >= label_start_time - 0.3: + start_time_align_count += 1 + break + for label_end_time in label_end_time_list: + for pred_end_time in pred_end_time_list: + if pred_end_time <= label_end_time + 0.3 and pred_end_time >= label_end_time - 0.3: + end_time_align_count += 1 + break + logger.info( + f"start-time 对齐个数 {start_time_align_count}, \ + end-time 对齐个数 {end_time_align_count}\ + 数据集中句子总数 {len(label_start_time_list)}" + ) + return start_time_align_count, end_time_align_count, start_end_count + + +def first_delay(context: ASRContext) -> Tuple: + first_send_time = context.preds[0].send_time + first_delay_list = [] + sentence_start = True + for pred_context in context.preds: + if sentence_start: + sentence_begin_time = pred_context.recognition_results.start_time + first_delay_time = pred_context.recv_time - 
def patch_unique_token_count(context: ASRContext):
    """Validate streaming-ASR patch behavior and count tokens per patch.

    Two things happen here:
    1. Check that no result modifies tokens that became "frozen" 3 seconds
       earlier (logs an error when violated).
    2. Build a Counter of the distinct token counts seen per sentence.

    Args:
        context: ASR context holding the predictions and language.

    Returns:
        Counter: distinct per-patch token counts, one entry per sentence.
    """
    # Tokenize every returned partial/final result.
    pred_text_list = [pred_context.recognition_results.text for pred_context in context.preds]
    pred_text_tokenized_list = Tokenizer.norm_and_tokenize(pred_text_list, lang=context.lang)

    # Detect modifications of tokens older than the 3-second freeze window.
    ## receive time of the first patch of the current sentence
    first_recv_time = None
    ## number of tokens that may no longer be modified
    unmodified_token_cnt = 0
    ## index of the oldest patch still inside the 3s window
    time_token_idx = 0
    ## True when the next patch starts a new sentence
    final_sentence = True

    ## set when a frozen token was modified
    is_unmodified_token = False

    for idx, (now_tokens, pred_context) in enumerate(zip(pred_text_tokenized_list, context.preds)):
        ## first patch of a sentence: reset the window bookkeeping
        if final_sentence:
            first_recv_time = pred_context.recv_time
            unmodified_token_cnt = 0
            time_token_idx = idx
            final_sentence = pred_context.recognition_results.final_result
            continue
        final_sentence = pred_context.recognition_results.final_result
        pred_recv_time = pred_context.recv_time
        ## ignore everything inside the first 3 seconds of the sentence
        if pred_recv_time - first_recv_time < 3:
            continue
        ## longest frozen prefix, derived from patches older than 3 seconds
        while time_token_idx < idx:
            context_pred_tmp = context.preds[time_token_idx]
            context_pred_tmp_recv_time = context_pred_tmp.recv_time
            tmp_tokens = pred_text_tokenized_list[time_token_idx]
            if pred_recv_time - context_pred_tmp_recv_time >= 3:
                unmodified_token_cnt = max(unmodified_token_cnt, len(tmp_tokens))
                time_token_idx += 1
            else:
                break
        ## compare against the previous patch: only the tail beyond
        ## unmodified_token_cnt may change
        last_tokens = pred_text_tokenized_list[idx - 1]
        if context.lang in ['ar', 'he']:
            # RTL languages: the frozen prefix is at the *end* of the string,
            # so compare the reversed token lists.
            # BUG FIX: a stray `continue` here made the reversal dead code and
            # skipped this check entirely for Arabic/Hebrew.
            tokens_check_pre, tokens_check_now = last_tokens[::-1], now_tokens[::-1]
        else:
            tokens_check_pre, tokens_check_now = last_tokens, now_tokens
        for token_a, token_b in zip(tokens_check_pre[:unmodified_token_cnt], tokens_check_now[:unmodified_token_cnt]):
            if token_a != token_b:
                is_unmodified_token = True
                break

        if is_unmodified_token and int(os.getenv('test', 0)):
            logger.error(
                f"{idx}-{unmodified_token_cnt}-{last_tokens[:unmodified_token_cnt]}-{now_tokens[:unmodified_token_cnt]}"
            )
        if is_unmodified_token:
            break

    if is_unmodified_token:
        logger.error("修改了不可修改的文字范围")
        # change_product_available()
    if int(os.getenv('test', 0)):
        # Test mode: dump patches grouped by sentence with relative times.
        final_result = True
        result_list = []
        for tokens, pred in zip(pred_text_tokenized_list, context.preds):
            if final_result:
                result_list.append([])
            result_list[-1].append((tokens, pred.recv_time - context.preds[0].recv_time))
            final_result = pred.recognition_results.final_result
        for item in result_list:
            logger.info(str(item))

    # Count the distinct token counts of each patch, per sentence.
    patch_unique_cnt_counter = Counter()
    patch_unique_cnt_in_one_sentence = set()
    for pred_text_tokenized, pred_context in zip(pred_text_tokenized_list, context.preds):
        token_cnt = len(pred_text_tokenized)
        patch_unique_cnt_in_one_sentence.add(token_cnt)
        if pred_context.recognition_results.final_result:
            for unique_cnt in patch_unique_cnt_in_one_sentence:
                patch_unique_cnt_counter[unique_cnt] += 1
            patch_unique_cnt_in_one_sentence.clear()
    # Flush a trailing, unfinished sentence as well.
    if context.preds and not context.preds[-1].recognition_results.final_result:
        for unique_cnt in patch_unique_cnt_in_one_sentence:
            patch_unique_cnt_counter[unique_cnt] += 1
    logger.info(
        f"当前音频的 patch token 均值为 {mean_on_counter(patch_unique_cnt_counter)}, \
        当前音频的 patch token 方差为 {var_on_counter(patch_unique_cnt_counter)}"
    )
    return patch_unique_cnt_counter
def sentence_final_token_index(tokens: List[List[str]], tokens_mapping: List[str]) -> List[int]:
    """Return, for each sentence, the mapping index of its last token.

    ``tokens`` holds per-sentence token lists; ``tokens_mapping`` is the same
    token stream after edit-distance alignment, where ``None`` marks an
    alignment gap that does not consume an original token.
    """
    final_indices = []
    cursor = 0
    total = len(tokens_mapping)
    for sentence_tokens in tokens:
        for _ in sentence_tokens:
            # Skip alignment gaps; they do not correspond to a real token.
            while cursor < total and tokens_mapping[cursor] is None:
                cursor += 1
            cursor += 1
        final_indices.append(cursor - 1)
    return final_indices
sentence_tmp_list.extend(sentence.split(punc)) + sentence_list, sentence_tmp_list = sentence_tmp_list, [] + sentence_list = [item for item in sentence_list if item] + + if tokenizerType == TokenizerType.whitespace: + sentence_cut_list.append(" ".join(sentence_list)) + else: + sentence_cut_list.append("".join(sentence_list)) + + return sentence_cut_list diff --git a/utils/model.py b/utils/model.py new file mode 100644 index 0000000..2eead0f --- /dev/null +++ b/utils/model.py @@ -0,0 +1,63 @@ +from pydantic import BaseModel, Field +from typing import ( + Optional, + List, + Any +) + + +class WordModel(BaseModel): + text: str + start_time: int # 或 float,取决时间戳格式 + end_time: int + segment: Optional[Any] = Field(default=None, exclude=True) # 所属文段 + # receive_time: Optional[Any] = None # 所属文段接收到的时间偏移,这里为了处理时方便,记录了ASRResultModel中的receive_time + class Config: + fields = { + 'segment': {'exclude': True} + } + +class SegmentModel(BaseModel): + # 文段接收到的时间 + receive_time: Optional[Any] = None + language: str + para_seq: int + final_result: bool + text: str + start_time: int # 或者 float,如果时间戳是毫秒精度 + end_time: int + words: List[WordModel] # 补充 words 字段 + + def summary(self) -> str: + duration = (self.end_time - self.start_time) / 1000 # 秒 + return ( + f"\n" + f"language:{self.language} \n" + f"para_seq:{self.para_seq} \n" + f"final_result {self.final_result}\n" + f"text:{self.text}\n" + f"words:[{', '.join(w.text for w in self.words)}]\n" + f"start_time:{self.start_time}\n" + f"end_time:{self.end_time}\n" + ) + + +class ASRResponseModel(BaseModel): + asr_results: SegmentModel + + + +class VoiceSegment(BaseModel): + answer: str + start: float + end: float + + +class AudioItem(BaseModel): + audio_length: float + duration: Optional[float] = None + file: str + orig_file: Optional[str] = None + voice: List[VoiceSegment] + absolute_path: Optional[str] = None + diff --git a/utils/platform_tools.py b/utils/platform_tools.py new file mode 100644 index 0000000..1387076 --- /dev/null +++ 
def mark_not_available():
    """Best-effort: flag the current submit as "product not available".

    POSTs to the leaderboard update endpoint; any failure is logged and
    swallowed so judging can continue.
    """
    token = os.getenv("LEADERBOARD_API_TOKEN")
    headers = {"Content-Type": "application/json"}
    if token:
        headers["Authorization"] = "Bearer " + token

    logger.info("更改为产品不可用...")
    try:
        submit_id = str(os.getenv("SUBMIT_ID", "-1"))
        # NOTE: "product_avaliable" is the key the backend expects (sic).
        payload = json.dumps({submit_id: {"product_avaliable": 0}})
        url = os.getenv("UPDATE_SUBMIT_URL", "http://contest.4pd.io:8080/submit/update")
        resp = requests.post(
            url,
            data=payload,
            headers=headers,
            timeout=600,
        )
        logger.info(resp.json())
    except Exception as e:
        logger.error(f"change product available error, {e}")
def extract_file(filepath: str, output_dir: str = ".") -> "str | None":
    """Extract a dataset archive into *output_dir*.

    The archive is treated as a zip file (datasets in use are zip archives
    that are not named ``*.zip``). A leading directory shared by every
    member is stripped so the content lands directly in *output_dir*.

    Args:
        filepath: path of the archive to extract.
        output_dir: destination directory; created if missing.

    Returns:
        Absolute path of the extracted ``data.yaml`` if the archive contains
        one, else ``None``. (The original annotation claimed ``-> None``,
        but callers rely on the returned path.)

    Raises:
        ValueError: if the archive contains no files.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    data_yaml_path = None
    with zipfile.ZipFile(filepath, 'r') as zf:
        # Only real files; directory entries end with '/'.
        all_files = [f for f in zf.namelist() if not f.endswith('/')]

        if not all_files:
            raise ValueError(f"数据集文件为空。{filepath}")

        # Length of the path prefix shared by every member.
        parts_list = [Path(f).parts for f in all_files]
        common_parts = os.path.commonprefix(parts_list)
        strip_prefix_len = len(common_parts)

        for file in all_files:
            file_parts = Path(file).parts
            # BUG FIX: never strip the final path component — for a
            # single-file archive the "common prefix" is the whole member
            # path and the old code produced an empty destination name.
            keep_from = min(strip_prefix_len, len(file_parts) - 1)
            relative_parts = file_parts[keep_from:]
            dest_path = Path(output_dir).joinpath(*relative_parts)

            # Create parent directories as needed.
            dest_path.parent.mkdir(parents=True, exist_ok=True)

            # Stream the member out without loading it fully into memory.
            with zf.open(file) as source, open(dest_path, "wb") as target:
                shutil.copyfileobj(source, target)

            # Remember where data.yaml ended up.
            if Path(file).name == "data.yaml":
                data_yaml_path = str(dest_path.resolve())

        logging.info(f"数据集解压成功。")

    return data_yaml_path
register_sut(st_config, resource_name) + print(sut_url) + return sut_url + + +if __name__ == '__main__': + start_server("/Users/yu/Documents/code-work/asr-live-iluvatar/script/config.yaml") diff --git a/utils/speechio/__init__.py b/utils/speechio/__init__.py new file mode 100644 index 0000000..dfb48db --- /dev/null +++ b/utils/speechio/__init__.py @@ -0,0 +1,3 @@ +''' +reference: https://github.com/SpeechColab/Leaderboard/tree/f287a992dc359d1c021bfc6ce810e5e36608e057/utils +''' diff --git a/utils/speechio/error_rate_en.py b/utils/speechio/error_rate_en.py new file mode 100644 index 0000000..352939f --- /dev/null +++ b/utils/speechio/error_rate_en.py @@ -0,0 +1,551 @@ +#!/usr/bin/env python3 +# coding=utf8 +# Copyright 2022 Zhenxiang MA, Jiayu DU (SpeechColab) + +import argparse +import csv +import json +import logging +import os +import sys +from typing import Iterable + +logging.basicConfig(stream=sys.stderr, level=logging.ERROR, format='[%(levelname)s] %(message)s') + +import pynini +from pynini.lib import pynutil + + +# reference: https://github.com/kylebgorman/pynini/blob/master/pynini/lib/edit_transducer.py +# to import original lib: +# from pynini.lib.edit_transducer import EditTransducer +class EditTransducer: + DELETE = "" + INSERT = "" + SUBSTITUTE = "" + + def __init__( + self, + symbol_table, + vocab: Iterable[str], + insert_cost: float = 1.0, + delete_cost: float = 1.0, + substitute_cost: float = 1.0, + bound: int = 0, + ): + # Left factor; note that we divide the edit costs by two because they also + # will be incurred when traversing the right factor. 
+ sigma = pynini.union( + *[pynini.accep(token, token_type=symbol_table) for token in vocab], + ).optimize() + + insert = pynutil.insert(f"[{self.INSERT}]", weight=insert_cost / 2) + delete = pynini.cross(sigma, pynini.accep(f"[{self.DELETE}]", weight=delete_cost / 2)) + substitute = pynini.cross(sigma, pynini.accep(f"[{self.SUBSTITUTE}]", weight=substitute_cost / 2)) + + edit = pynini.union(insert, delete, substitute).optimize() + + if bound: + sigma_star = pynini.closure(sigma) + self._e_i = sigma_star.copy() + for _ in range(bound): + self._e_i.concat(edit.ques).concat(sigma_star) + else: + self._e_i = edit.union(sigma).closure() + + self._e_i.optimize() + + right_factor_std = EditTransducer._right_factor(self._e_i) + # right_factor_ext allows 0-cost matching between token's raw form & auxiliary form + # e.g.: 'I' -> 'I#', 'AM' -> 'AM#' + right_factor_ext = ( + pynini.union( + *[ + pynini.cross( + pynini.accep(x, token_type=symbol_table), + pynini.accep(x + '#', token_type=symbol_table), + ) + for x in vocab + ] + ) + .optimize() + .closure() + ) + self._e_o = pynini.union(right_factor_std, right_factor_ext).closure().optimize() + + @staticmethod + def _right_factor(ifst: pynini.Fst) -> pynini.Fst: + ofst = pynini.invert(ifst) + syms = pynini.generated_symbols() + insert_label = syms.find(EditTransducer.INSERT) + delete_label = syms.find(EditTransducer.DELETE) + pairs = [(insert_label, delete_label), (delete_label, insert_label)] + right_factor = ofst.relabel_pairs(ipairs=pairs) + return right_factor + + def create_lattice(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> pynini.Fst: + lattice = (iexpr @ self._e_i) @ (self._e_o @ oexpr) + EditTransducer.check_wellformed_lattice(lattice) + return lattice + + @staticmethod + def check_wellformed_lattice(lattice: pynini.Fst) -> None: + if lattice.start() == pynini.NO_STATE_ID: + raise RuntimeError("Edit distance composition lattice is empty.") + + def compute_distance(self, iexpr: pynini.FstLike, oexpr: 
pynini.FstLike) -> float: + lattice = self.create_lattice(iexpr, oexpr) + # The shortest cost from all final states to the start state is + # equivalent to the cost of the shortest path. + start = lattice.start() + return float(pynini.shortestdistance(lattice, reverse=True)[start]) + + def compute_alignment(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> pynini.FstLike: + print(iexpr) + print(oexpr) + lattice = self.create_lattice(iexpr, oexpr) + alignment = pynini.shortestpath(lattice, nshortest=1, unique=True) + return alignment.optimize() + + +class ErrorStats: + def __init__(self): + self.num_ref_utts = 0 + self.num_hyp_utts = 0 + self.num_eval_utts = 0 # in both ref & hyp + self.num_hyp_without_ref = 0 + + self.C = 0 + self.S = 0 + self.I = 0 + self.D = 0 + self.token_error_rate = 0.0 + self.modified_token_error_rate = 0.0 + + self.num_utts_with_error = 0 + self.sentence_error_rate = 0.0 + + def to_json(self): + # return json.dumps(self.__dict__, indent=4) + return json.dumps(self.__dict__) + + def to_kaldi(self): + info = ( + F'%WER {self.token_error_rate:.2f} [ {self.S + self.D + self.I} / {self.C + self.S + self.D}, {self.I} ins, {self.D} del, {self.S} sub ]\n' + F'%SER {self.sentence_error_rate:.2f} [ {self.num_utts_with_error} / {self.num_eval_utts} ]\n' + ) + return info + + def to_summary(self): + summary = ( + '==================== Overall Statistics ====================\n' + F'num_ref_utts: {self.num_ref_utts}\n' + F'num_hyp_utts: {self.num_hyp_utts}\n' + F'num_hyp_without_ref: {self.num_hyp_without_ref}\n' + F'num_eval_utts: {self.num_eval_utts}\n' + F'sentence_error_rate: {self.sentence_error_rate:.2f}%\n' + F'token_error_rate: {self.token_error_rate:.2f}%\n' + F'modified_token_error_rate: {self.modified_token_error_rate:.2f}%\n' + F'token_stats:\n' + F' - tokens:{self.C + self.S + self.D:>7}\n' + F' - edits: {self.S + self.I + self.D:>7}\n' + F' - cor: {self.C:>7}\n' + F' - sub: {self.S:>7}\n' + F' - ins: {self.I:>7}\n' + F' - del: 
{self.D:>7}\n' + '============================================================\n' + ) + return summary + + +class Utterance: + def __init__(self, uid, text): + self.uid = uid + self.text = text + + +def LoadKaldiArc(filepath): + utts = {} + with open(filepath, 'r', encoding='utf8') as f: + for line in f: + line = line.strip() + if line: + cols = line.split(maxsplit=1) + assert len(cols) == 2 or len(cols) == 1 + uid = cols[0] + text = cols[1] if len(cols) == 2 else '' + if utts.get(uid) != None: + raise RuntimeError(F'Found duplicated utterence id {uid}') + utts[uid] = Utterance(uid, text) + return utts + + +def BreakHyphen(token: str): + # 'T-SHIRT' should also introduce new words into vocabulary, e.g.: + # 1. 'T' & 'SHIRT' + # 2. 'TSHIRT' + assert '-' in token + v = token.split('-') + v.append(token.replace('-', '')) + return v + + +def LoadGLM(rel_path): + ''' + glm.csv: + I'VE,I HAVE + GOING TO,GONNA + ... + T-SHIRT,T SHIRT,TSHIRT + + glm: + { + '': ["I'VE", 'I HAVE'], + '': ['GOING TO', 'GONNA'], + ... 
+ '': ['T-SHIRT', 'T SHIRT', 'TSHIRT'], + } + ''' + logging.info(f'Loading GLM from {rel_path} ...') + + abs_path = os.path.dirname(os.path.abspath(__file__)) + '/' + rel_path + reader = list(csv.reader(open(abs_path, encoding="utf-8"), delimiter=',')) + + glm = {} + for k, rule in enumerate(reader): + rule_name = f'' + glm[rule_name] = [phrase.strip() for phrase in rule] + logging.info(f' #rule: {len(glm)}') + + return glm + + +def SymbolEQ(symbol_table, i1, i2): + return symbol_table.find(i1).strip('#') == symbol_table.find(i2).strip('#') + + +def PrintSymbolTable(symbol_table: pynini.SymbolTable): + print('SYMBOL_TABLE:') + for k in range(symbol_table.num_symbols()): + sym = symbol_table.find(k) + assert symbol_table.find(sym) == k # symbol table's find can be used for bi-directional lookup (id <-> sym) + print(k, sym) + print() + + +def BuildSymbolTable(vocab) -> pynini.SymbolTable: + logging.info('Building symbol table ...') + symbol_table = pynini.SymbolTable() + symbol_table.add_symbol('') + + for w in vocab: + symbol_table.add_symbol(w) + logging.info(f' #symbols: {symbol_table.num_symbols()}') + + # PrintSymbolTable(symbol_table) + # symbol_table.write_text('symbol_table.txt') + return symbol_table + + +def BuildGLMTagger(glm, symbol_table) -> pynini.Fst: + logging.info('Building GLM tagger ...') + rule_taggers = [] + for rule_tag, rule in glm.items(): + for phrase in rule: + rule_taggers.append( + ( + pynutil.insert(pynini.accep(rule_tag, token_type=symbol_table)) + + pynini.accep(phrase, token_type=symbol_table) + + pynutil.insert(pynini.accep(rule_tag, token_type=symbol_table)) + ) + ) + + alphabet = pynini.union( + *[pynini.accep(sym, token_type=symbol_table) for k, sym in symbol_table if k != 0] # non-epsilon + ).optimize() + + tagger = pynini.cdrewrite( + pynini.union(*rule_taggers).optimize(), '', '', alphabet.closure() + ).optimize() # could be slow with large vocabulary + return tagger + + +def TokenWidth(token: str): + def CharWidth(c): + return 
2 if (c >= '\u4e00') and (c <= '\u9fa5') else 1 + + return sum([CharWidth(c) for c in token]) + + +def PrintPrettyAlignment(raw_hyp, edit_ali, ref_ali, hyp_ali, stream=sys.stderr): + assert len(edit_ali) == len(ref_ali) and len(ref_ali) == len(hyp_ali) + + H = ' HYP# : ' + R = ' REF : ' + E = ' EDIT : ' + for i, e in enumerate(edit_ali): + h, r = hyp_ali[i], ref_ali[i] + e = '' if e == 'C' else e # don't bother printing correct edit-tag + + nr, nh, ne = TokenWidth(r), TokenWidth(h), TokenWidth(e) + n = max(nr, nh, ne) + 1 + + H += h + ' ' * (n - nh) + R += r + ' ' * (n - nr) + E += e + ' ' * (n - ne) + + print(F' HYP : {raw_hyp}', file=stream) + print(H, file=stream) + print(R, file=stream) + print(E, file=stream) + + +def ComputeTokenErrorRate(c, s, i, d): + assert (s + d + c) != 0 + num_edits = s + d + i + ref_len = c + s + d + hyp_len = c + s + i + return 100.0 * num_edits / ref_len, 100.0 * num_edits / max(ref_len, hyp_len) + + +def ComputeSentenceErrorRate(num_err_utts, num_utts): + assert num_utts != 0 + return 100.0 * num_err_utts / num_utts + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--logk', type=int, default=500, help='logging interval') + parser.add_argument( + '--tokenizer', choices=['whitespace', 'char'], default='whitespace', help='whitespace for WER, char for CER' + ) + parser.add_argument('--glm', type=str, default='glm_en.csv', help='glm') + parser.add_argument('--ref', type=str, required=True, help='reference kaldi arc file') + parser.add_argument('--hyp', type=str, required=True, help='hypothesis kaldi arc file') + parser.add_argument('result_file', type=str) + args = parser.parse_args() + logging.info(args) + + stats = ErrorStats() + + logging.info('Generating tokenizer ...') + if args.tokenizer == 'whitespace': + + def word_tokenizer(text): + return text.strip().split() + + tokenizer = word_tokenizer + elif args.tokenizer == 'char': + + def char_tokenizer(text): + return [c for c in 
text.strip().replace(' ', '')] + + tokenizer = char_tokenizer + else: + tokenizer = None + assert tokenizer + + logging.info('Loading REF & HYP ...') + ref_utts = LoadKaldiArc(args.ref) + hyp_utts = LoadKaldiArc(args.hyp) + + # check valid utterances in hyp that have matched non-empty reference + uids = [] + for uid in sorted(hyp_utts.keys()): + if uid in ref_utts.keys(): + if ref_utts[uid].text.strip(): # non-empty reference + uids.append(uid) + else: + logging.warning(F'Found {uid} with empty reference, skipping...') + else: + logging.warning(F'Found {uid} without reference, skipping...') + stats.num_hyp_without_ref += 1 + + stats.num_hyp_utts = len(hyp_utts) + stats.num_ref_utts = len(ref_utts) + stats.num_eval_utts = len(uids) + logging.info(f' #hyp:{stats.num_hyp_utts}, #ref:{stats.num_ref_utts}, #utts_to_evaluate:{stats.num_eval_utts}') + print(f' #hyp:{stats.num_hyp_utts}, #ref:{stats.num_ref_utts}, #utts_to_evaluate:{stats.num_eval_utts}') + + tokens = [] + for uid in uids: + ref_tokens = tokenizer(ref_utts[uid].text) + hyp_tokens = tokenizer(hyp_utts[uid].text) + for t in ref_tokens + hyp_tokens: + tokens.append(t) + if '-' in t: + tokens.extend(BreakHyphen(t)) + vocab_from_utts = list(set(tokens)) + logging.info(f' HYP&REF vocab size: {len(vocab_from_utts)}') + print(f' HYP&REF vocab size: {len(vocab_from_utts)}') + + assert args.glm + glm = LoadGLM(args.glm) + + tokens = [] + for rule in glm.values(): + for phrase in rule: + for t in tokenizer(phrase): + tokens.append(t) + if '-' in t: + tokens.extend(BreakHyphen(t)) + vocab_from_glm = list(set(tokens)) + logging.info(f' GLM vocab size: {len(vocab_from_glm)}') + print(f' GLM vocab size: {len(vocab_from_glm)}') + + vocab = list(set(vocab_from_utts + vocab_from_glm)) + logging.info(f'Global vocab size: {len(vocab)}') + print(f'Global vocab size: {len(vocab)}') + + symtab = BuildSymbolTable( + # Normal evaluation vocab + auxiliary form for alternative paths + GLM tags + vocab + + [x + '#' for x in vocab] + 
+ [x for x in glm.keys()] + ) + glm_tagger = BuildGLMTagger(glm, symtab) + edit_transducer = EditTransducer(symbol_table=symtab, vocab=vocab) + print(edit_transducer) + + logging.info('Evaluating error rate ...') + print('Evaluating error rate ...') + fo = open(args.result_file, 'w+', encoding='utf8') + ndone = 0 + for uid in uids: + ref = ref_utts[uid].text + raw_hyp = hyp_utts[uid].text + + ref_fst = pynini.accep(' '.join(tokenizer(ref)), token_type=symtab) + print(ref_fst) + + # print(ref_fst.string(token_type = symtab)) + + raw_hyp_fst = pynini.accep(' '.join(tokenizer(raw_hyp)), token_type=symtab) + # print(raw_hyp_fst.string(token_type = symtab)) + + # Say, we have: + # RULE_001: "I'M" <-> "I AM" + # REF: HEY I AM HERE + # HYP: HEY I'M HERE + # + # We want to expand HYP with GLM rules(marked with auxiliary #) + # HYP#: HEY {I'M | I# AM#} HERE + # REF is honored to keep its original form. + # + # This could be considered as a flexible on-the-fly TN towards HYP. + + # 1. GLM rule tagging: + # HEY I'M HERE + # -> + # HEY I'M HERE + lattice = (raw_hyp_fst @ glm_tagger).optimize() + tagged_ir = pynini.shortestpath(lattice, nshortest=1, unique=True).string(token_type=symtab) + # print(hyp_tagged) + + # 2. 
GLM rule expansion: + # HEY I'M HERE + # -> + # sausage-like fst: HEY {I'M | I# AM#} HERE + tokens = tagged_ir.split() + sausage = pynini.accep('', token_type=symtab) + i = 0 + while i < len(tokens): # invariant: tokens[0, i) has been built into fst + forms = [] + if tokens[i].startswith(''): # rule segment + rule_name = tokens[i] + rule = glm[rule_name] + # pre-condition: i -> ltag + raw_form = '' + for j in range(i + 1, len(tokens)): + if tokens[j] == rule_name: + raw_form = ' '.join(tokens[i + 1 : j]) + break + assert raw_form + # post-condition: i -> ltag, j -> rtag + + forms.append(raw_form) + for phrase in rule: + if phrase != raw_form: + forms.append(' '.join([x + '#' for x in phrase.split()])) + i = j + 1 + else: # normal token segment + token = tokens[i] + forms.append(token) + if "-" in token: # token with hyphen yields extra forms + forms.append(' '.join([x + '#' for x in token.split('-')])) # 'T-SHIRT' -> 'T# SHIRT#' + forms.append(token.replace('-', '') + '#') # 'T-SHIRT' -> 'TSHIRT#' + i += 1 + + sausage_segment = pynini.union(*[pynini.accep(x, token_type=symtab) for x in forms]).optimize() + sausage += sausage_segment + hyp_fst = sausage.optimize() + print(hyp_fst) + + # Utterance-Level error rate evaluation + alignment = edit_transducer.compute_alignment(ref_fst, hyp_fst) + print("alignment", alignment) + + distance = 0.0 + C, S, I, D = 0, 0, 0, 0 # Cor, Sub, Ins, Del + edit_ali, ref_ali, hyp_ali = [], [], [] + for state in alignment.states(): + for arc in alignment.arcs(state): + i, o = arc.ilabel, arc.olabel + if i != 0 and o != 0 and SymbolEQ(symtab, i, o): + e = 'C' + r, h = symtab.find(i), symtab.find(o) + + C += 1 + distance += 0.0 + elif i != 0 and o != 0 and not SymbolEQ(symtab, i, o): + e = 'S' + r, h = symtab.find(i), symtab.find(o) + + S += 1 + distance += 1.0 + elif i == 0 and o != 0: + e = 'I' + r, h = '*', symtab.find(o) + + I += 1 + distance += 1.0 + elif i != 0 and o == 0: + e = 'D' + r, h = symtab.find(i), '*' + + D += 1 + distance 
+= 1.0 + else: + raise RuntimeError + + edit_ali.append(e) + ref_ali.append(r) + hyp_ali.append(h) + # assert(distance == edit_transducer.compute_distance(ref_fst, sausage)) + + utt_ter, utt_mter = ComputeTokenErrorRate(C, S, I, D) + # print(F'{{"uid":{uid}, "score":{-distance}, "TER":{utt_ter:.2f}, "mTER":{utt_mter:.2f}, "cor":{C}, "sub":{S}, "ins":{I}, "del":{D}}}', file=fo) + # PrintPrettyAlignment(raw_hyp, edit_ali, ref_ali, hyp_ali, fo) + + if utt_ter > 0: + stats.num_utts_with_error += 1 + + stats.C += C + stats.S += S + stats.I += I + stats.D += D + + ndone += 1 + if ndone % args.logk == 0: + logging.info(f'{ndone} utts evaluated.') + logging.info(f'{ndone} utts evaluated in total.') + + # Corpus-Level evaluation + stats.token_error_rate, stats.modified_token_error_rate = ComputeTokenErrorRate(stats.C, stats.S, stats.I, stats.D) + stats.sentence_error_rate = ComputeSentenceErrorRate(stats.num_utts_with_error, stats.num_eval_utts) + + print(stats.to_json(), file=fo) + # print(stats.to_kaldi()) + # print(stats.to_summary(), file=fo) + + fo.close() diff --git a/utils/speechio/error_rate_zh.py b/utils/speechio/error_rate_zh.py new file mode 100644 index 0000000..6871a07 --- /dev/null +++ b/utils/speechio/error_rate_zh.py @@ -0,0 +1,370 @@ +#!/usr/bin/env python3 +# coding=utf8 + +# Copyright 2021 Jiayu DU + +import sys +import argparse +import json +import logging +logging.basicConfig(stream=sys.stderr, level=logging.INFO, format='[%(levelname)s] %(message)s') + +DEBUG = None + +def GetEditType(ref_token, hyp_token): + if ref_token == None and hyp_token != None: + return 'I' + elif ref_token != None and hyp_token == None: + return 'D' + elif ref_token == hyp_token: + return 'C' + elif ref_token != hyp_token: + return 'S' + else: + raise RuntimeError + +class AlignmentArc: + def __init__(self, src, dst, ref, hyp): + self.src = src + self.dst = dst + self.ref = ref + self.hyp = hyp + self.edit_type = GetEditType(ref, hyp) + +def 
similarity_score_function(ref_token, hyp_token): + return 0 if (ref_token == hyp_token) else -1.0 + +def insertion_score_function(token): + return -1.0 + +def deletion_score_function(token): + return -1.0 + +def EditDistance( + ref, + hyp, + similarity_score_function = similarity_score_function, + insertion_score_function = insertion_score_function, + deletion_score_function = deletion_score_function): + assert(len(ref) != 0) + class DPState: + def __init__(self): + self.score = -float('inf') + # backpointer + self.prev_r = None + self.prev_h = None + + def print_search_grid(S, R, H, fstream): + print(file=fstream) + for r in range(R): + for h in range(H): + print(F'[{r},{h}]:{S[r][h].score:4.3f}:({S[r][h].prev_r},{S[r][h].prev_h}) ', end='', file=fstream) + print(file=fstream) + + R = len(ref) + 1 + H = len(hyp) + 1 + + # Construct DP search space, a (R x H) grid + S = [ [] for r in range(R) ] + for r in range(R): + S[r] = [ DPState() for x in range(H) ] + + # initialize DP search grid origin, S(r = 0, h = 0) + S[0][0].score = 0.0 + S[0][0].prev_r = None + S[0][0].prev_h = None + + # initialize REF axis + for r in range(1, R): + S[r][0].score = S[r-1][0].score + deletion_score_function(ref[r-1]) + S[r][0].prev_r = r-1 + S[r][0].prev_h = 0 + + # initialize HYP axis + for h in range(1, H): + S[0][h].score = S[0][h-1].score + insertion_score_function(hyp[h-1]) + S[0][h].prev_r = 0 + S[0][h].prev_h = h-1 + + best_score = S[0][0].score + best_state = (0, 0) + + for r in range(1, R): + for h in range(1, H): + sub_or_cor_score = similarity_score_function(ref[r-1], hyp[h-1]) + new_score = S[r-1][h-1].score + sub_or_cor_score + if new_score >= S[r][h].score: + S[r][h].score = new_score + S[r][h].prev_r = r-1 + S[r][h].prev_h = h-1 + + del_score = deletion_score_function(ref[r-1]) + new_score = S[r-1][h].score + del_score + if new_score >= S[r][h].score: + S[r][h].score = new_score + S[r][h].prev_r = r - 1 + S[r][h].prev_h = h + + ins_score = 
insertion_score_function(hyp[h-1]) + new_score = S[r][h-1].score + ins_score + if new_score >= S[r][h].score: + S[r][h].score = new_score + S[r][h].prev_r = r + S[r][h].prev_h = h-1 + + best_score = S[R-1][H-1].score + best_state = (R-1, H-1) + + if DEBUG: + print_search_grid(S, R, H, sys.stderr) + + # Backtracing best alignment path, i.e. a list of arcs + # arc = (src, dst, ref, hyp, edit_type) + # src/dst = (r, h), where r/h refers to search grid state-id along Ref/Hyp axis + best_path = [] + r, h = best_state[0], best_state[1] + prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h + score = S[r][h].score + # loop invariant: + # 1. (prev_r, prev_h) -> (r, h) is a "forward arc" on best alignment path + # 2. score is the value of point(r, h) on DP search grid + while prev_r != None or prev_h != None: + src = (prev_r, prev_h) + dst = (r, h) + if (r == prev_r + 1 and h == prev_h + 1): # Substitution or correct + arc = AlignmentArc(src, dst, ref[prev_r], hyp[prev_h]) + elif (r == prev_r + 1 and h == prev_h): # Deletion + arc = AlignmentArc(src, dst, ref[prev_r], None) + elif (r == prev_r and h == prev_h + 1): # Insertion + arc = AlignmentArc(src, dst, None, hyp[prev_h]) + else: + raise RuntimeError + best_path.append(arc) + r, h = prev_r, prev_h + prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h + score = S[r][h].score + + best_path.reverse() + return (best_path, best_score) + +def PrettyPrintAlignment(alignment, stream = sys.stderr): + def get_token_str(token): + if token == None: + return "*" + return token + + def is_double_width_char(ch): + if (ch >= '\u4e00') and (ch <= '\u9fa5'): # codepoint ranges for Chinese chars + return True + # TODO: support other double-width-char language such as Japanese, Korean + else: + return False + + def display_width(token_str): + m = 0 + for c in token_str: + if is_double_width_char(c): + m += 2 + else: + m += 1 + return m + + R = ' REF : ' + H = ' HYP : ' + E = ' EDIT : ' + for arc in alignment: + r = get_token_str(arc.ref) + h = 
get_token_str(arc.hyp) + e = arc.edit_type if arc.edit_type != 'C' else '' + + nr, nh, ne = display_width(r), display_width(h), display_width(e) + n = max(nr, nh, ne) + 1 + + R += r + ' ' * (n-nr) + H += h + ' ' * (n-nh) + E += e + ' ' * (n-ne) + + print(R, file=stream) + print(H, file=stream) + print(E, file=stream) + +def CountEdits(alignment): + c, s, i, d = 0, 0, 0, 0 + for arc in alignment: + if arc.edit_type == 'C': + c += 1 + elif arc.edit_type == 'S': + s += 1 + elif arc.edit_type == 'I': + i += 1 + elif arc.edit_type == 'D': + d += 1 + else: + raise RuntimeError + return (c, s, i, d) + +def ComputeTokenErrorRate(c, s, i, d): + return 100.0 * (s + d + i) / (s + d + c) + +def ComputeSentenceErrorRate(num_err_utts, num_utts): + assert(num_utts != 0) + return 100.0 * num_err_utts / num_utts + + +class EvaluationResult: + def __init__(self): + self.num_ref_utts = 0 + self.num_hyp_utts = 0 + self.num_eval_utts = 0 # seen in both ref & hyp + self.num_hyp_without_ref = 0 + + self.C = 0 + self.S = 0 + self.I = 0 + self.D = 0 + self.token_error_rate = 0.0 + + self.num_utts_with_error = 0 + self.sentence_error_rate = 0.0 + + def to_json(self): + return json.dumps(self.__dict__) + + def to_kaldi(self): + info = ( + F'%WER {self.token_error_rate:.2f} [ {self.S + self.D + self.I} / {self.C + self.S + self.D}, {self.I} ins, {self.D} del, {self.S} sub ]\n' + F'%SER {self.sentence_error_rate:.2f} [ {self.num_utts_with_error} / {self.num_eval_utts} ]\n' + ) + return info + + def to_sclite(self): + return "TODO" + + def to_espnet(self): + return "TODO" + + def to_summary(self): + #return json.dumps(self.__dict__, indent=4) + summary = ( + '==================== Overall Statistics ====================\n' + F'num_ref_utts: {self.num_ref_utts}\n' + F'num_hyp_utts: {self.num_hyp_utts}\n' + F'num_hyp_without_ref: {self.num_hyp_without_ref}\n' + F'num_eval_utts: {self.num_eval_utts}\n' + F'sentence_error_rate: {self.sentence_error_rate:.2f}%\n' + F'token_error_rate: 
{self.token_error_rate:.2f}%\n' + F'token_stats:\n' + F' - tokens:{self.C + self.S + self.D:>7}\n' + F' - edits: {self.S + self.I + self.D:>7}\n' + F' - cor: {self.C:>7}\n' + F' - sub: {self.S:>7}\n' + F' - ins: {self.I:>7}\n' + F' - del: {self.D:>7}\n' + '============================================================\n' + ) + return summary + + +class Utterance: + def __init__(self, uid, text): + self.uid = uid + self.text = text + + +def LoadUtterances(filepath, format): + utts = {} + if format == 'text': # utt_id word1 word2 ... + with open(filepath, 'r', encoding='utf8') as f: + for line in f: + line = line.strip() + if line: + cols = line.split(maxsplit=1) + assert(len(cols) == 2 or len(cols) == 1) + uid = cols[0] + text = cols[1] if len(cols) == 2 else '' + if utts.get(uid) != None: + raise RuntimeError(F'Found duplicated utterence id {uid}') + utts[uid] = Utterance(uid, text) + else: + raise RuntimeError(F'Unsupported text format {format}') + return utts + + +def tokenize_text(text, tokenizer): + if tokenizer == 'whitespace': + return text.split() + elif tokenizer == 'char': + return [ ch for ch in ''.join(text.split()) ] + else: + raise RuntimeError(F'ERROR: Unsupported tokenizer {tokenizer}') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + # optional + parser.add_argument('--tokenizer', choices=['whitespace', 'char'], default='whitespace', help='whitespace for WER, char for CER') + parser.add_argument('--ref-format', choices=['text'], default='text', help='reference format, first col is utt_id, the rest is text') + parser.add_argument('--hyp-format', choices=['text'], default='text', help='hypothesis format, first col is utt_id, the rest is text') + # required + parser.add_argument('--ref', type=str, required=True, help='input reference file') + parser.add_argument('--hyp', type=str, required=True, help='input hypothesis file') + + parser.add_argument('result_file', type=str) + args = parser.parse_args() + logging.info(args) + + 
ref_utts = LoadUtterances(args.ref, args.ref_format) + hyp_utts = LoadUtterances(args.hyp, args.hyp_format) + + r = EvaluationResult() + + # check valid utterances in hyp that have matched non-empty reference + eval_utts = [] + r.num_hyp_without_ref = 0 + for uid in sorted(hyp_utts.keys()): + if uid in ref_utts.keys(): # TODO: efficiency + if ref_utts[uid].text.strip(): # non-empty reference + eval_utts.append(uid) + else: + logging.warn(F'Found {uid} with empty reference, skipping...') + else: + logging.warn(F'Found {uid} without reference, skipping...') + r.num_hyp_without_ref += 1 + + r.num_hyp_utts = len(hyp_utts) + r.num_ref_utts = len(ref_utts) + r.num_eval_utts = len(eval_utts) + + with open(args.result_file, 'w+', encoding='utf8') as fo: + for uid in eval_utts: + ref = ref_utts[uid] + hyp = hyp_utts[uid] + + alignment, score = EditDistance( + tokenize_text(ref.text, args.tokenizer), + tokenize_text(hyp.text, args.tokenizer) + ) + + c, s, i, d = CountEdits(alignment) + utt_ter = ComputeTokenErrorRate(c, s, i, d) + + # utt-level evaluation result + print(F'{{"uid":{uid}, "score":{score}, "ter":{utt_ter:.2f}, "cor":{c}, "sub":{s}, "ins":{i}, "del":{d}}}', file=fo) + PrettyPrintAlignment(alignment, fo) + + r.C += c + r.S += s + r.I += i + r.D += d + + if utt_ter > 0: + r.num_utts_with_error += 1 + + # corpus level evaluation result + r.sentence_error_rate = ComputeSentenceErrorRate(r.num_utts_with_error, r.num_eval_utts) + r.token_error_rate = ComputeTokenErrorRate(r.C, r.S, r.I, r.D) + + print(r.to_summary(), file=fo) + + print(r.to_json()) + print(r.to_kaldi()) diff --git a/utils/speechio/glm_en.csv b/utils/speechio/glm_en.csv new file mode 100644 index 0000000..2bc14f7 --- /dev/null +++ b/utils/speechio/glm_en.csv @@ -0,0 +1,744 @@ +I'M,I AM +I'LL,I WILL +I'D,I HAD +I'VE,I HAVE +I WOULD'VE,I'D HAVE +YOU'RE,YOU ARE +YOU'LL,YOU WILL +YOU'D,YOU WOULD +YOU'VE,YOU HAVE +HE'S,HE IS,HE WAS +HE'LL,HE WILL +HE'D,HE HAD +SHE'S,SHE IS,SHE WAS +SHE'LL,SHE WILL 
+SHE'D,SHE HAD +IT'S,IT IS,IT WAS +IT'LL,IT WILL +WE'RE,WE ARE,WE WERE +WE'LL,WE WILL +WE'D,WE WOULD +WE'VE,WE HAVE +WHO'LL,WHO WILL +THEY'RE,THEY ARE +THEY'LL,THEY WILL +THAT'S,THAT IS,THAT WAS +THAT'LL,THAT WILL +HERE'S,HERE IS,HERE WAS +THERE'S,THERE IS,THERE WAS +WHERE'S,WHERE IS,WHERE WAS +WHAT'S,WHAT IS,WHAT WAS +LET'S,LET US +WHO'S,WHO IS +ONE'S,ONE IS +THERE'LL,THERE WILL +SOMEBODY'S,SOMEBODY IS +EVERYBODY'S,EVERYBODY IS +WOULD'VE,WOULD HAVE +CAN'T,CANNOT,CAN NOT +HADN'T,HAD NOT +HASN'T,HAS NOT +HAVEN'T,HAVE NOT +ISN'T,IS NOT +AREN'T,ARE NOT +WON'T,WILL NOT +WOULDN'T,WOULD NOT +SHOULDN'T,SHOULD NOT +DON'T,DO NOT +DIDN'T,DID NOT +GOTTA,GOT TO +GONNA,GOING TO +WANNA,WANT TO +LEMME,LET ME +GIMME,GIVE ME +DUNNO,DON'T KNOW +GOTCHA,GOT YOU +KINDA,KIND OF +MYSELF,MY SELF +YOURSELF,YOUR SELF +HIMSELF,HIM SELF +HERSELF,HER SELF +ITSELF,IT SELF +OURSELVES,OUR SELVES +OKAY,OK,O K +Y'ALL,YALL,YOU ALL +'CAUSE,'COS,CUZ,BECAUSE +FUCKIN',FUCKING +KILLING,KILLIN' +EVERYDAY,EVERY DAY +DOCTOR,DR,DR. 
+MRS,MISSES,MISSUS +MR,MR.,MISTER +SR,SR.,SENIOR +JR,JR.,JUNIOR +ST,ST.,SAINT +VOL,VOL.,VOLUME +CM,CENTIMETER,CENTIMETRE +MM,MILLIMETER,MILLIMETRE +KM,KILOMETER,KILOMETRE +KB,KILOBYTES,KILO BYTES,K B +MB,MEGABYTES,MEGA BYTES +GB,GIGABYTES,GIGA BYTES,G B +THOUSAND,THOUSAND AND +HUNDRED,HUNDRED AND +A HUNDRED,ONE HUNDRED +TWO THOUSAND AND,TWENTY,TWO THOUSAND +STORYTELLER,STORY TELLER +TSHIRT,T SHIRT +TSHIRTS,T SHIRTS +LEUKAEMIA,LEUKEMIA +OESTROGEN,ESTROGEN +ACKNOWLEDGMENT,ACKNOWLEDGEMENT +JUDGMENT,JUDGEMENT +MAMMA,MAMA +DINING,DINNING +FLACK,FLAK +LEARNT,LEARNED +BLONDE,BLOND +JUMPSTART,JUMP START +RIGHTNOW,RIGHT NOW +EVERYONE,EVERY ONE +NAME'S,NAME IS +FAMILY'S,FAMILY IS +COMPANY'S,COMPANY HAS +GRANDKID,GRAND KID +GRANDKIDS,GRAND KIDS +MEALTIMES,MEAL TIMES +ALRIGHT,ALL RIGHT +GROWNUP,GROWN UP +GROWNUPS,GROWN UPS +SCHOOLDAYS,SCHOOL DAYS +SCHOOLCHILDREN,SCHOOL CHILDREN +CASEBOOK,CASE BOOK +HUNGOVER,HUNG OVER +HANDCLAPS,HAND CLAPS +HANDCLAP,HAND CLAP +HEATWAVE,HEAT WAVE +ADDON,ADD ON +ONTO,ON TO +INTO,IN TO +GOTO,GO TO +GUNSHOT,GUN SHOT +MOTHERFUCKER,MOTHER FUCKER +OFTENTIMES,OFTEN TIMES +SARTRE'S,SARTRE IS +NONSTARTER,NON STARTER +NONSTARTERS,NON STARTERS +LONGTIME,LONG TIME +POLICYMAKERS,POLICY MAKERS +ANYMORE,ANY MORE +CANADA'S,CANADA IS +CELLPHONE,CELL PHONE +WORKPLACE,WORK PLACE +UNDERESTIMATING,UNDER ESTIMATING +CYBERSECURITY,CYBER SECURITY +NORTHEAST,NORTH EAST +ANYTIME,ANY TIME +LIVESTREAM,LIVE STREAM +LIVESTREAMS,LIVE STREAMS +WEBCAM,WEB CAM +EMAIL,E MAIL +ECAM,E CAM +VMIX,V MIX +SETUP,SET UP +SMARTPHONE,SMART PHONE +MULTICASTING,MULTI CASTING +CHITCHAT,CHIT CHAT +SEMIFINAL,SEMI FINAL +SEMIFINALS,SEMI FINALS +BBQ,BARBECUE +STORYLINE,STORY LINE +STORYLINES,STORY LINES +BRO,BROTHER +BROS,BROTHERS +OVERPROTECTIIVE,OVER PROTECTIVE +TIMEOUT,TIME OUT +ADVISOR,ADVISER +TIMBERWOLVES,TIMBER WOLVES +WEBPAGE,WEB PAGE +NEWCOMER,NEW COMER +DELMAR,DEL MAR +NETPLAY,NET PLAY +STREETSIDE,STREET SIDE +COLOURED,COLORED +COLOURFUL,COLORFUL +O,ZERO +ETCETERA,ET CETERA 
+FUNDRAISING,FUND RAISING +RAINFOREST,RAIN FOREST +BREATHTAKING,BREATH TAKING +WIKIPAGE,WIKI PAGE +OVERTIME,OVER TIME +TRAIN'S TRAIN IS +ANYONE,ANY ONE +PHYSIOTHERAPY,PHYSIO THERAPY +ANYBODY,ANY BODY +BOTTLECAPS,BOTTLE CAPS +BOTTLECAP,BOTTLE CAP +STEPFATHER'S,STEP FATHER'S +STEPFATHER,STEP FATHER +WARTIME,WAR TIME +SCREENSHOT,SCREEN SHOT +TIMELINE,TIME LINE +CITY'S,CITY IS +NONPROFIT,NON PROFIT +KPOP,K POP +HOMEBASE,HOME BASE +LIFELONG,LIFE LONG +LAWSUITS,LAW SUITS +MULTIBILLION,MULTI BILLION +ROADMAP,ROAD MAP +GUY'S,GUY IS +CHECKOUT,CHECK OUT +SQUARESPACE,SQUARE SPACE +REDLINING,RED LINING +BASE'S,BASE IS +TAKEAWAY,TAKE AWAY +CANDYLAND,CANDY LAND +ANTISOCIAL,ANTI SOCIAL +CASEWORK,CASE WORK +RIGOR,RIGOUR +ORGANIZATIONS,ORGANISATIONS +ORGANIZATION,ORGANISATION +SIGNPOST,SIGN POST +WWII,WORLD WAR TWO +WINDOWPANE,WINDOW PANE +SUREFIRE,SURE FIRE +MOUNTAINTOP,MOUNTAIN TOP +SALESPERSON,SALES PERSON +NETWORK,NET WORK +MINISERIES,MINI SERIES +EDWARDS'S,EDWARDS IS +INTERSUBJECTIVITY,INTER SUBJECTIVITY +LIBERALISM'S,LIBERALISM IS +TAGLINE,TAG LINE +SHINETHEORY,SHINE THEORY +CALLYOURGIRLFRIEND,CALL YOUR GIRLFRIEND +STARTUP,START UP +BREAKUP,BREAK UP +RADIOTOPIA,RADIO TOPIA +HEARTBREAKING,HEART BREAKING +AUTOIMMUNE,AUTO IMMUNE +SINISE'S,SINISE IS +KICKBACK,KICK BACK +FOGHORN,FOG HORN +BADASS,BAD ASS +POWERAMERICAFORWARD,POWER AMERICA FORWARD +GOOGLE'S,GOOGLE IS +ROLEPLAY,ROLE PLAY +PRICE'S,PRICE IS +STANDOFF,STAND OFF +FOREVER,FOR EVER +GENERAL'S,GENERAL IS +DOG'S,DOG IS +AUDIOBOOK,AUDIO BOOK +ANYWAY,ANY WAY +PIGEONHOLE,PIEGON HOLE +EGGSHELLS,EGG SHELLS +VACCINE'S,VACCINE IS +WORKOUT,WORK OUT +ADMINISTRATOR'S,ADMINISTRATOR IS +FUCKUP,FUCK UP +RUNOFFS,RUN OFFS +COLORWAY,COLOR WAY +WAITLIST,WAIT LIST +HEALTHCARE,HEALTH CARE +TEXTBOOK,TEXT BOOK +CALLBACK,CALL BACK +PARTYGOERS,PARTY GOERS +SOMEDAY,SOME DAY +NIGHTGOWN,NIGHT GOWN +STANDALONG,STAND ALONG +BUSSINESSWOMAN,BUSSINESS WOMAN +STORYTELLING,STORY TELLING +MARKETPLACE,MARKET PLACE +CRATEJOY,CRATE JOY +OUTPERFORMED,OUT 
PERFORMED +TRUEBOTANICALS,TRUE BOTANICALS +NONFICTION,NON FICTION +SPINOFF,SPIN OFF +MOTHERFUCKING,MOTHER FUCKING +TRACKLIST,TRACK LIST +GODDAMN,GOD DAMN +PORNHUB,PORN HUB +UNDERAGE,UNDER AGE +GOODBYE,GOOD BYE +HARDCORE,HARD CORE +TRUCK'S,TRUCK IS +COUNTERSTEERING,COUNTER STEERING +BUZZWORD,BUZZ WORD +SUBCOMPONENTS,SUB COMPONENTS +MOREOVER,MORE OVER +PICKUP,PICK UP +NEWSLETTER,NEWS LETTER +KEYWORD,KEY WORD +LOGIN,LOG IN +TOOLBOX,TOOL BOX +LINK'S,LINK IS +PRIMIALVIDEO,PRIMAL VIDEO +DOTNET,DOT NET +AIRSTRIKE,AIR STRIKE +HAIRSTYLE,HAIR STYLE +TOWNSFOLK,TOWNS FOLK +GOLDFISH,GOLD FISH +TOM'S,TOM IS +HOMETOWN,HOME TOWN +CORONAVIRUS,CORONA VIRUS +PLAYSTATION,PLAY STATION +TOMORROW,TO MORROW +TIMECONSUMING,TIME CONSUMING +POSTWAR,POST WAR +HANDSON,HANDS ON +SHAKEUP,SHAKE UP +ECOMERS,E COMERS +COFOUNDER,CO FOUNDER +HIGHEND,HIGH END +INPERSON,IN PERSON +GROWNUP,GROWN UP +SELFREGULATION,SELF REGULATION +INDEPTH,IN DEPTH +ALLTIME,ALL TIME +LONGTERM,LONG TERM +SOCALLED,SO CALLED +SELFCONFIDENCE,SELF CONFIDENCE +STANDUP,STAND UP +MINDBOGGLING,MIND BOGGLING +BEINGFOROTHERS,BEING FOR OTHERS +COWROTE,CO WROTE +COSTARRED,CO STARRED +EDITORINCHIEF,EDITOR IN CHIEF +HIGHSPEED,HIGH SPEED +DECISIONMAKING,DECISION MAKING +WELLBEING,WELL BEING +NONTRIVIAL,NON TRIVIAL +PREEXISTING,PRE EXISTING +STATEOWNED,STATE OWNED +PLUGIN,PLUG IN +PROVERSION,PRO VERSION +OPTIN,OPT IN +FOLLOWUP,FOLLOW UP +FOLLOWUPS,FOLLOW UPS +WIFI,WI FI +THIRDPARTY,THIRD PARTY +PROFESSIONALLOOKING,PROFESSIONAL LOOKING +FULLSCREEN,FULL SCREEN +BUILTIN,BUILT IN +MULTISTREAM,MULTI STREAM +LOWCOST,LOW COST +RESTREAM,RE STREAM +GAMECHANGER,GAME CHANGER +WELLDEVELOPED,WELL DEVELOPED +QUARTERINCH,QUARTER INCH +FASTFASHION,FAST FASHION +ECOMMERCE,E COMMERCE +PRIZEWINNING,PRIZE WINNING +NEVERENDING,NEVER ENDING +MINDBLOWING,MIND BLOWING +REALLIFE,REAL LIFE +REOPEN,RE OPEN +ONDEMAND,ON DEMAND +PROBLEMSOLVING,PROBLEM SOLVING +HEAVYHANDED,HEAVY HANDED +OPENENDED,OPEN ENDED +SELFCONTROL,SELF CONTROL +WELLMEANING,WELL MEANING 
+COHOST,CO HOST +RIGHTSBASED,RIGHTS BASED +HALFBROTHER,HALF BROTHER +FATHERINLAW,FATHER IN LAW +COAUTHOR,CO AUTHOR +REELECTION,RE ELECTION +SELFHELP,SELF HELP +PROLIFE,PRO LIFE +ANTIDUKE,ANTI DUKE +POSTSTRUCTURALIST,POST STRUCTURALIST +COFOUNDED,CO FOUNDED +XRAY,X RAY +ALLAROUND,ALL AROUND +HIGHTECH,HIGH TECH +TMOBILE,T MOBILE +INHOUSE,IN HOUSE +POSTMORTEM,POST MORTEM +LITTLEKNOWN,LITTLE KNOWN +FALSEPOSITIVE,FALSE POSITIVE +ANTIVAXXER,ANTI VAXXER +EMAILS,E MAILS +DRIVETHROUGH,DRIVE THROUGH +DAYTODAY,DAY TO DAY +COSTAR,CO STAR +EBAY,E BAY +KOOLAID,KOOL AID +ANTIDEMOCRATIC,ANTI DEMOCRATIC +MIDDLEAGED,MIDDLE AGED +SHORTLIVED,SHORT LIVED +BESTSELLING,BEST SELLING +TICTACS,TIC TACS +UHHUH,UH HUH +MULTITANK,MULTI TANK +JAWDROPPING,JAW DROPPING +LIVESTREAMING,LIVE STREAMING +HARDWORKING,HARD WORKING +BOTTOMDWELLING,BOTTOM DWELLING +PRESHOW,PRE SHOW +HANDSFREE,HANDS FREE +TRICKORTREATING,TRICK OR TREATING +PRERECORDED,PRE RECORDED +DOGOODERS,DO GOODERS +WIDERANGING,WIDE RANGING +LIFESAVING,LIFE SAVING +SKIREPORT,SKI REPORT +SNOWBASE,SNOW BASE +JAYZ,JAY Z +SPIDERMAN,SPIDER MAN +FREEKICK,FREE KICK +EDWARDSHELAIRE,EDWARDS HELAIRE +SHORTTERM,SHORT TERM +HAVENOTS,HAVE NOTS +SELFINTEREST,SELF INTEREST +SELFINTERESTED,SELF INTERESTED +SELFCOMPASSION,SELF COMPASSION +MACHINELEARNING,MACHINE LEARNING +COAUTHORED,CO AUTHORED +NONGOVERNMENT,NON GOVERNMENT +SUBSAHARAN,SUB SAHARAN +COCHAIR,CO CHAIR +LARGESCALE,LARGE SCALE +VIDEOONDEMAND,VIDEO ON DEMAND +FIRSTCLASS,FIRST CLASS +COFOUNDERS,CO FOUNDERS +COOP,CO OP +PREORDERS,PRE ORDERS +DOUBLEENTRY,DOUBLE ENTRY +SELFCONFIDENT,SELF CONFIDENT +SELFPORTRAIT,SELF PORTRAIT +NONWHITE,NON WHITE +ONBOARD,ON BOARD +HALFLIFE,HALF LIFE +ONCOURT,ON COURT +SCIFI,SCI FI +XMEN,X MEN +DAYLEWIS,DAY LEWIS +LALALAND,LA LA LAND +AWARDWINNING,AWARD WINNING +BOXOFFICE,BOX OFFICE +TRIDACTYLS,TRI DACTYLS +TRIDACTYL,TRI DACTYL +MEDIUMSIZED,MEDIUM SIZED +POSTSECONDARY,POST SECONDARY +FULLTIME,FULL TIME +GOKART,GO KART +OPENAIR,OPEN AIR +WELLKNOWN,WELL KNOWN 
+ICECREAM,ICE CREAM +EARTHMOON,EARTH MOON +STATEOFTHEART,STATE OF THE ART +BSIDE,B SIDE +EASTWEST,EAST WEST +ALLSTAR,ALL STAR +RUNNERUP,RUNNER UP +HORSEDRAWN,HORSE DRAWN +OPENSOURCE,OPEN SOURCE +PURPOSEBUILT,PURPOSE BUILT +SQUAREFREE,SQUARE FREE +PRESENTDAY,PRESENT DAY +CANADAUNITED,CANADA UNITED +HOTCHPOTCH,HOTCH POTCH +LOWLYING,LOW LYING +RIGHTHANDED,RIGHT HANDED +PEARSHAPED,PEAR SHAPED +BESTKNOWN,BEST KNOWN +FULLLENGTH,FULL LENGTH +YEARROUND,YEAR ROUND +PREELECTION,PRE ELECTION +RERECORD,RE RECORD +MINIALBUM,MINI ALBUM +LONGESTRUNNING,LONGEST RUNNING +ALLIRELAND,ALL IRELAND +NORTHWESTERN,NORTH WESTERN +PARTTIME,PART TIME +NONGOVERNMENTAL,NON GOVERNMENTAL +ONLINE,ON LINE +ONAIR,ON AIR +NORTHSOUTH,NORTH SOUTH +RERELEASED,RE RELEASED +LEFTHANDED,LEFT HANDED +BSIDES,B SIDES +ANGLOSAXON,ANGLO SAXON +SOUTHSOUTHEAST,SOUTH SOUTHEAST +CROSSCOUNTRY,CROSS COUNTRY +REBUILT,RE BUILT +FREEFORM,FREE FORM +SCOOBYDOO,SCOOBY DOO +ATLARGE,AT LARGE +COUNCILMANAGER,COUNCIL MANAGER +LONGRUNNING,LONG RUNNING +PREWAR,PRE WAR +REELECTED,RE ELECTED +HIGHSCHOOL,HIGH SCHOOL +RUNNERSUP,RUNNERS UP +NORTHWEST,NORTH WEST +WEBBASED,WEB BASED +HIGHQUALITY,HIGH QUALITY +RIGHTWING,RIGHT WING +LANEFOX,LANE FOX +PAYPERVIEW,PAY PER VIEW +COPRODUCTION,CO PRODUCTION +NONPARTISAN,NON PARTISAN +FIRSTPERSON,FIRST PERSON +WORLDRENOWNED,WORLD RENOWNED +VICEPRESIDENT,VICE PRESIDENT +PROROMAN,PRO ROMAN +COPRODUCED,CO PRODUCED +LOWPOWER,LOW POWER +SELFESTEEM,SELF ESTEEM +SEMITRANSPARENT,SEMI TRANSPARENT +SECONDINCOMMAND,SECOND IN COMMAND +HIGHRISE,HIGH RISE +COHOSTED,CO HOSTED +AFRICANAMERICAN,AFRICAN AMERICAN +SOUTHWEST,SOUTH WEST +WELLPRESERVED,WELL PRESERVED +FEATURELENGTH,FEATURE LENGTH +HIPHOP,HIP HOP +ALLBIG,ALL BIG +SOUTHEAST,SOUTH EAST +COUNTERATTACK,COUNTER ATTACK +QUARTERFINALS,QUARTER FINALS +STABLEDOOR,STABLE DOOR +DARKEYED,DARK EYED +ALLAMERICAN,ALL AMERICAN +THIRDPERSON,THIRD PERSON +LOWLEVEL,LOW LEVEL +NTERMINAL,N TERMINAL +DRIEDUP,DRIED UP +AFRICANAMERICANS,AFRICAN AMERICANS +ANTIAPARTHEID,ANTI 
APARTHEID +STOKEONTRENT,STOKE ON TRENT +NORTHNORTHEAST,NORTH NORTHEAST +BRANDNEW,BRAND NEW +RIGHTANGLED,RIGHT ANGLED +GOVERNMENTOWNED,GOVERNMENT OWNED +SONINLAW,SON IN LAW +SUBJECTOBJECTVERB,SUBJECT OBJECT VERB +LEFTARM,LEFT ARM +LONGLIVED,LONG LIVED +REDEYE,RED EYE +TPOSE,T POSE +NIGHTVISION,NIGHT VISION +SOUTHEASTERN,SOUTH EASTERN +WELLRECEIVED,WELL RECEIVED +ALFAYOUM,AL FAYOUM +TIMEBASED,TIME BASED +KETTLEDRUMS,KETTLE DRUMS +BRIGHTEYED,BRIGHT EYED +REDBROWN,RED BROWN +SAMESEX,SAME SEX +PORTDEPAIX,PORT DE PAIX +CLEANUP,CLEAN UP +PERCENT,PERCENT SIGN +TAKEOUT,TAKE OUT +KNOWHOW,KNOW HOW +FISHBONE,FISH BONE +FISHSTICKS,FISH STICKS +PAPERWORK,PAPER WORK +NICKNACKS,NICK NACKS +STREETTALKING,STREET TALKING +NONACADEMIC,NON ACADEMIC +SHELLY,SHELLEY +SHELLY'S,SHELLEY'S +JIMMY,JIMMIE +JIMMY'S,JIMMIE'S +DRUGSTORE,DRUG STORE +THRU,THROUGH +PLAYDATE,PLAY DATE +MICROLIFE,MICRO LIFE +SKILLSET,SKILL SET +SKILLSETS,SKILL SETS +TRADEOFF,TRADE OFF +TRADEOFFS,TRADE OFFS +ONSCREEN,ON SCREEN +PLAYBACK,PLAY BACK +ARTWORK,ART WORK +COWORKER,CO WORKER +COWORKERS,CO WORKERS +SOMETIME,SOME TIME +SOMETIMES,SOME TIMES +CROWDFUNDING,CROWD FUNDING +AM,A.M.,A M +PM,P.M.,P M +TV,T V +MBA,M B A +USA,U S A +US,U S +UK,U K +CEO,C E O +CFO,C F O +COO,C O O +CIO,C I O +FM,F M +GMC,G M C +FSC,F S C +NPD,N P D +APM,A P M +NGO,N G O +TD,T D +LOL,L O L +IPO,I P O +CNBC,C N B C +IPOS,I P OS +CNBC's,C N B C'S +JT,J T +NPR,N P R +NPR'S,N P R'S +MP,M P +IOI,I O I +DW,D W +CNN,C N N +WSM,W S M +ET,E T +IT,I T +RJ,R J +DVD,D V D +DVD'S,D V D'S +HBO,H B O +LA,L A +XC,X C +SUV,S U V +NBA,N B A +NBA'S,N B A'S +ESPN,E S P N +ESPN'S,E S P N'S +ADT,A D T +HD,H D +VIP,V I P +TMZ,T M Z +CBC,C B C +NPO,N P O +BBC,B B C +LA'S,L A'S +TMZ'S,T M Z'S +HIV,H I V +FTC,F T C +EU,E U +PHD,P H D +AI,A I +FHI,F H I +ICML,I C M L +ICLR,I C L R +BMW,B M W +EV,E V +CR,C R +API,A P I +ICO,I C O +LTE,L T E +OBS,O B S +PC,P C +IO,I O +CRM,C R M +RTMP,R T M P +ASMR,A S M R +GG,G G +WWW,W W W +PEI,P E I +JJ,J J +PT,P T +DJ,D J +SD,S D 
+POW,P.O.W.,P O W +FYI,F Y I +DC,D C,D.C +ABC,A B C +TJ,T J +WMDT,W M D T +WDTN,W D T N +TY,T Y +EJ,E J +CJ,C J +ACL,A C L +UK'S,U K'S +GTV,G T V +MDMA,M D M A +DFW,D F W +WTF,W T F +AJ,A J +MD,M D +PH,P H +ID,I D +SEO,S E O +UTM'S,U T M'S +EC,E C +UFC,U F C +RV,R V +UTM,U T M +CSV,C S V +SMS,S M S +GRB,G R B +GT,G T +LEM,L E M +XR,X R +EDU,E D U +NBC,N B C +EMS,E M S +CDC,C D C +MLK,M L K +IE,I E +OC,O C +HR,H R +MA,M A +DEE,D E E +AP,A P +UFO,U F O +DE,D E +LGBTQ,L G B T Q +PTA,P T A +NHS,N H S +CMA,C M A +MGM,M G M +AKA,A K A +HW,H W +GOP,G O P +GOP'S,G O P'S +FBI,F B I +PRX,P R X +CTO,C T O +URL,U R L +EIN,E I N +MLS,M L S +CSI,C S I +AOC,A O C +CND,C N D +CP,C P +PP,P P +CLI,C L I +PB,P B +FDA,F D A +MRNA,M R N A +PR,P R +VP,V P +DNC,D N C +MSNBC,M S N B C +GQ,G Q +UT,U T +XXI,X X I +HRV,H R V +WHO,W H O +CRO,C R O +DPA,D P A +PPE,P P E +EVA,E V A +BP,B P +GPS,G P S +AR,A R +PJ,P J +MLM,M L M +OLED,O L E D +BO,B O +VE,V E +UN,U N +SLS,S L S +DM,D M +DM'S,D M'S +ASAP,A S A P +ETA,E T A +DOB,D O B +BMW,B M W diff --git a/utils/speechio/interjections_en.csv b/utils/speechio/interjections_en.csv new file mode 100644 index 0000000..1afbd3b --- /dev/null +++ b/utils/speechio/interjections_en.csv @@ -0,0 +1,20 @@ +ach +ah +eee +eh +er +ew +ha +hee +hm +hmm +hmmm +huh +mm +mmm +oof +uh +uhh +um +oh +hum \ No newline at end of file diff --git a/utils/speechio/nemo_text_processing/README.md b/utils/speechio/nemo_text_processing/README.md new file mode 100644 index 0000000..63ea610 --- /dev/null +++ b/utils/speechio/nemo_text_processing/README.md @@ -0,0 +1 @@ +nemo_version from commit:eae1684f7f33c2a18de9ecfa42ec7db93d39e631 diff --git a/utils/speechio/nemo_text_processing/__init__.py b/utils/speechio/nemo_text_processing/__init__.py new file mode 100644 index 0000000..bc443be --- /dev/null +++ b/utils/speechio/nemo_text_processing/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/utils/speechio/nemo_text_processing/text_normalization/README.md b/utils/speechio/nemo_text_processing/text_normalization/README.md new file mode 100644 index 0000000..d14e4d1 --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/README.md @@ -0,0 +1,10 @@ +# Text Normalization + +Text Normalization is part of NeMo's `nemo_text_processing` - a Python package that is installed with the `nemo_toolkit`. +It converts text from written form into its verbalized form, e.g. "123" -> "one hundred twenty three". + +See [NeMo documentation](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/nlp/text_normalization/wfst/wfst_text_normalization.html) for details. 
+ +Tutorial with overview of the package capabilities: [Text_(Inverse)_Normalization.ipynb](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/text_processing/Text_(Inverse)_Normalization.ipynb) + +Tutorial on how to customize the underlying grammars: [WFST_Tutorial.ipynb](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/text_processing/WFST_Tutorial.ipynb) diff --git a/utils/speechio/nemo_text_processing/text_normalization/__init__.py b/utils/speechio/nemo_text_processing/text_normalization/__init__.py new file mode 100644 index 0000000..bc443be --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/utils/speechio/nemo_text_processing/text_normalization/data_loader_utils.py b/utils/speechio/nemo_text_processing/text_normalization/data_loader_utils.py new file mode 100644 index 0000000..d6713c9 --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/data_loader_utils.py @@ -0,0 +1,350 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +import re +import string +from collections import defaultdict, namedtuple +from typing import Dict, List, Optional, Set, Tuple +from unicodedata import category + + + +EOS_TYPE = "EOS" +PUNCT_TYPE = "PUNCT" +PLAIN_TYPE = "PLAIN" +Instance = namedtuple('Instance', 'token_type un_normalized normalized') +known_types = [ + "PLAIN", + "DATE", + "CARDINAL", + "LETTERS", + "VERBATIM", + "MEASURE", + "DECIMAL", + "ORDINAL", + "DIGIT", + "MONEY", + "TELEPHONE", + "ELECTRONIC", + "FRACTION", + "TIME", + "ADDRESS", +] + + +def _load_kaggle_text_norm_file(file_path: str) -> List[Instance]: + """ + https://www.kaggle.com/richardwilliamsproat/text-normalization-for-english-russian-and-polish + Loads text file in the Kaggle Google text normalization file format: \t\t<`self` if trivial class or normalized text> + E.g. + PLAIN Brillantaisia + PLAIN is + PLAIN a + PLAIN genus + PLAIN of + PLAIN plant + PLAIN in + PLAIN family + PLAIN Acanthaceae + PUNCT . 
sil + + + Args: + file_path: file path to text file + + Returns: flat list of instances + """ + res = [] + with open(file_path, 'r') as fp: + for line in fp: + parts = line.strip().split("\t") + if parts[0] == "": + res.append(Instance(token_type=EOS_TYPE, un_normalized="", normalized="")) + else: + l_type, l_token, l_normalized = parts + l_token = l_token.lower() + l_normalized = l_normalized.lower() + + if l_type == PLAIN_TYPE: + res.append(Instance(token_type=l_type, un_normalized=l_token, normalized=l_token)) + elif l_type != PUNCT_TYPE: + res.append(Instance(token_type=l_type, un_normalized=l_token, normalized=l_normalized)) + return res + + +def load_files(file_paths: List[str], load_func=_load_kaggle_text_norm_file) -> List[Instance]: + """ + Load given list of text files using the `load_func` function. + + Args: + file_paths: list of file paths + load_func: loading function + + Returns: flat list of instances + """ + res = [] + for file_path in file_paths: + res.extend(load_func(file_path=file_path)) + return res + + +def clean_generic(text: str) -> str: + """ + Cleans text without affecting semiotic classes. + + Args: + text: string + + Returns: cleaned string + """ + text = text.strip() + text = text.lower() + return text + + +def evaluate(preds: List[str], labels: List[str], input: Optional[List[str]] = None, verbose: bool = True) -> float: + """ + Evaluates accuracy given predictions and labels. 
+ + Args: + preds: predictions + labels: labels + input: optional, only needed for verbosity + verbose: if true prints [input], golden labels and predictions + + Returns accuracy + """ + acc = 0 + nums = len(preds) + for i in range(nums): + pred_norm = clean_generic(preds[i]) + label_norm = clean_generic(labels[i]) + if pred_norm == label_norm: + acc = acc + 1 + else: + if input: + print(f"inpu: {json.dumps(input[i])}") + print(f"gold: {json.dumps(label_norm)}") + print(f"pred: {json.dumps(pred_norm)}") + return acc / nums + + +def training_data_to_tokens( + data: List[Instance], category: Optional[str] = None +) -> Dict[str, Tuple[List[str], List[str]]]: + """ + Filters the instance list by category if provided and converts it into a map from token type to list of un_normalized and normalized strings + + Args: + data: list of instances + category: optional semiotic class category name + + Returns Dict: token type -> (list of un_normalized strings, list of normalized strings) + """ + result = defaultdict(lambda: ([], [])) + for instance in data: + if instance.token_type != EOS_TYPE: + if category is None or instance.token_type == category: + result[instance.token_type][0].append(instance.un_normalized) + result[instance.token_type][1].append(instance.normalized) + return result + + +def training_data_to_sentences(data: List[Instance]) -> Tuple[List[str], List[str], List[Set[str]]]: + """ + Takes instance list, creates list of sentences split by EOS_Token + Args: + data: list of instances + Returns (list of unnormalized sentences, list of normalized sentences, list of sets of categories in a sentence) + """ + # split data at EOS boundaries + sentences = [] + sentence = [] + categories = [] + sentence_categories = set() + + for instance in data: + if instance.token_type == EOS_TYPE: + sentences.append(sentence) + sentence = [] + categories.append(sentence_categories) + sentence_categories = set() + else: + sentence.append(instance) + 
sentence_categories.update([instance.token_type]) + un_normalized = [" ".join([instance.un_normalized for instance in sentence]) for sentence in sentences] + normalized = [" ".join([instance.normalized for instance in sentence]) for sentence in sentences] + return un_normalized, normalized, categories + + +def post_process_punctuation(text: str) -> str: + """ + Normalized quotes and spaces + + Args: + text: text + + Returns: text with normalized spaces and quotes + """ + text = ( + text.replace('( ', '(') + .replace(' )', ')') + .replace('{ ', '{') + .replace(' }', '}') + .replace('[ ', '[') + .replace(' ]', ']') + .replace(' ', ' ') + .replace('”', '"') + .replace("’", "'") + .replace("»", '"') + .replace("«", '"') + .replace("\\", "") + .replace("„", '"') + .replace("´", "'") + .replace("’", "'") + .replace('“', '"') + .replace("‘", "'") + .replace('`', "'") + .replace('- -', "--") + ) + + for punct in "!,.:;?": + text = text.replace(f' {punct}', punct) + return text.strip() + + +def pre_process(text: str) -> str: + """ + Optional text preprocessing before normalization (part of TTS TN pipeline) + + Args: + text: string that may include semiotic classes + + Returns: text with spaces around punctuation marks + """ + space_both = '[]' + for punct in space_both: + text = text.replace(punct, ' ' + punct + ' ') + + # remove extra space + text = re.sub(r' +', ' ', text) + return text + + +def load_file(file_path: str) -> List[str]: + """ + Loads given text file with separate lines into list of string. + + Args: + file_path: file path + + Returns: flat list of string + """ + res = [] + with open(file_path, 'r') as fp: + for line in fp: + res.append(line) + return res + + +def write_file(file_path: str, data: List[str]): + """ + Writes out list of string to file. 
+ + Args: + file_path: file path + data: list of string + + """ + with open(file_path, 'w') as fp: + for line in data: + fp.write(line + '\n') + + +def post_process_punct(input: str, normalized_text: str, add_unicode_punct: bool = False): + """ + Post-processing of the normalized output to match input in terms of spaces around punctuation marks. + After NN normalization, Moses detokenization puts a space after + punctuation marks, and attaches an opening quote "'" to the word to the right. + E.g., input to the TN NN model is "12 test' example", + after normalization and detokenization -> "twelve test 'example" (the quote is considered to be an opening quote, + but it doesn't match the input and can cause issues during TTS voice generation.) + The current function will match the punctuation and spaces of the normalized text with the input sequence. + "12 test' example" -> "twelve test 'example" -> "twelve test' example" (the quote was shifted to match the input). + + Args: + input: input text (original input to the NN, before normalization or tokenization) + normalized_text: output text (output of the TN NN model) + add_unicode_punct: set to True to handle unicode punctuation marks as well as default string.punctuation (increases post processing time) + """ + # in the post-processing WFST graph "``" are repalced with '"" quotes (otherwise single quotes "`" won't be handled correctly) + # this function fixes spaces around them based on input sequence, so here we're making the same double quote replacement + # to make sure these new double quotes work with this function + if "``" in input and "``" not in normalized_text: + input = input.replace("``", '"') + input = [x for x in input] + normalized_text = [x for x in normalized_text] + punct_marks = [x for x in string.punctuation if x in input] + + if add_unicode_punct: + punct_unicode = [ + chr(i) + for i in range(sys.maxunicode) + if category(chr(i)).startswith("P") and chr(i) not in punct_default and chr(i) in input 
+ ] + punct_marks = punct_marks.extend(punct_unicode) + + for punct in punct_marks: + try: + equal = True + if input.count(punct) != normalized_text.count(punct): + equal = False + idx_in, idx_out = 0, 0 + while punct in input[idx_in:]: + idx_out = normalized_text.index(punct, idx_out) + idx_in = input.index(punct, idx_in) + + def _is_valid(idx_out, idx_in, normalized_text, input): + """Check if previous or next word match (for cases when punctuation marks are part of + semiotic token, i.e. some punctuation can be missing in the normalized text)""" + return (idx_out > 0 and idx_in > 0 and normalized_text[idx_out - 1] == input[idx_in - 1]) or ( + idx_out < len(normalized_text) - 1 + and idx_in < len(input) - 1 + and normalized_text[idx_out + 1] == input[idx_in + 1] + ) + + if not equal and not _is_valid(idx_out, idx_in, normalized_text, input): + idx_in += 1 + continue + if idx_in > 0 and idx_out > 0: + if normalized_text[idx_out - 1] == " " and input[idx_in - 1] != " ": + normalized_text[idx_out - 1] = "" + + elif normalized_text[idx_out - 1] != " " and input[idx_in - 1] == " ": + normalized_text[idx_out - 1] += " " + + if idx_in < len(input) - 1 and idx_out < len(normalized_text) - 1: + if normalized_text[idx_out + 1] == " " and input[idx_in + 1] != " ": + normalized_text[idx_out + 1] = "" + elif normalized_text[idx_out + 1] != " " and input[idx_in + 1] == " ": + normalized_text[idx_out] = normalized_text[idx_out] + " " + idx_out += 1 + idx_in += 1 + except: + pass + + normalized_text = "".join(normalized_text) + return re.sub(r' +', ' ', normalized_text) diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/__init__.py b/utils/speechio/nemo_text_processing/text_normalization/en/__init__.py new file mode 100644 index 0000000..a9d7d97 --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_text_processing.text_normalization.en.taggers.tokenize_and_classify import ClassifyFst +from nemo_text_processing.text_normalization.en.verbalizers.verbalize import VerbalizeFst +from nemo_text_processing.text_normalization.en.verbalizers.verbalize_final import VerbalizeFinalFst diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/clean_eval_data.py b/utils/speechio/nemo_text_processing/text_normalization/en/clean_eval_data.py new file mode 100644 index 0000000..8c33c4f --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/clean_eval_data.py @@ -0,0 +1,342 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from argparse import ArgumentParser +from typing import List + +import regex as re +from nemo_text_processing.text_normalization.data_loader_utils import ( + EOS_TYPE, + Instance, + load_files, + training_data_to_sentences, +) + + +""" +This file is for evaluation purposes. +filter_loaded_data() cleans data (list of instances) for text normalization. Filters and cleaners can be specified for each semiotic class individually. +For example, normalized text should only include characters and whitespace characters but no punctuation. + Cardinal unnormalized instances should contain at least one integer and all other characters are removed. +""" + + +class Filter: + """ + Filter class + + Args: + class_type: semiotic class used in dataset + process_func: function to transform text + filter_func: function to filter text + + """ + + def __init__(self, class_type: str, process_func: object, filter_func: object): + self.class_type = class_type + self.process_func = process_func + self.filter_func = filter_func + + def filter(self, instance: Instance) -> bool: + """ + filter function + + Args: + filters given instance with filter function + + Returns: True if given instance fulfills criteria or does not belong to class type + """ + if instance.token_type != self.class_type: + return True + return self.filter_func(instance) + + def process(self, instance: Instance) -> Instance: + """ + process function + + Args: + processes given instance with process function + + Returns: processed instance if instance belongs to expected class type or original instance + """ + if instance.token_type != self.class_type: + return instance + return self.process_func(instance) + + +def filter_cardinal_1(instance: Instance) -> bool: + ok = re.search(r"[0-9]", instance.un_normalized) + return ok + + +def process_cardinal_1(instance: Instance) -> Instance: + un_normalized = instance.un_normalized + normalized = instance.normalized + un_normalized = re.sub(r"[^0-9]", "", un_normalized) + 
normalized = re.sub(r"[^a-z ]", "", normalized) + return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized) + + +def filter_ordinal_1(instance: Instance) -> bool: + ok = re.search(r"(st|nd|rd|th)\s*$", instance.un_normalized) + return ok + + +def process_ordinal_1(instance: Instance) -> Instance: + un_normalized = instance.un_normalized + normalized = instance.normalized + un_normalized = re.sub(r"[,\s]", "", un_normalized) + normalized = re.sub(r"[^a-z ]", "", normalized) + return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized) + + +def filter_decimal_1(instance: Instance) -> bool: + ok = re.search(r"[0-9]", instance.un_normalized) + return ok + + +def process_decimal_1(instance: Instance) -> Instance: + un_normalized = instance.un_normalized + un_normalized = re.sub(r",", "", un_normalized) + normalized = instance.normalized + normalized = re.sub(r"[^a-z ]", "", normalized) + return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized) + + +def filter_measure_1(instance: Instance) -> bool: + ok = True + return ok + + +def process_measure_1(instance: Instance) -> Instance: + un_normalized = instance.un_normalized + normalized = instance.normalized + un_normalized = re.sub(r",", "", un_normalized) + un_normalized = re.sub(r"m2", "m²", un_normalized) + un_normalized = re.sub(r"(\d)([^\d.\s])", r"\1 \2", un_normalized) + normalized = re.sub(r"[^a-z\s]", "", normalized) + normalized = re.sub(r"per ([a-z\s]*)s$", r"per \1", normalized) + normalized = re.sub(r"[^a-z ]", "", normalized) + return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized) + + +def filter_money_1(instance: Instance) -> bool: + ok = re.search(r"[0-9]", instance.un_normalized) + return ok + + +def process_money_1(instance: Instance) -> Instance: + un_normalized = instance.un_normalized + normalized = instance.normalized + un_normalized 
= re.sub(r",", "", un_normalized) + un_normalized = re.sub(r"a\$", r"$", un_normalized) + un_normalized = re.sub(r"us\$", r"$", un_normalized) + un_normalized = re.sub(r"(\d)m\s*$", r"\1 million", un_normalized) + un_normalized = re.sub(r"(\d)bn?\s*$", r"\1 billion", un_normalized) + normalized = re.sub(r"[^a-z ]", "", normalized) + return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized) + + +def filter_time_1(instance: Instance) -> bool: + ok = re.search(r"[0-9]", instance.un_normalized) + return ok + + +def process_time_1(instance: Instance) -> Instance: + un_normalized = instance.un_normalized + un_normalized = re.sub(r": ", ":", un_normalized) + un_normalized = re.sub(r"(\d)\s?a\s?m\s?", r"\1 a.m.", un_normalized) + un_normalized = re.sub(r"(\d)\s?p\s?m\s?", r"\1 p.m.", un_normalized) + normalized = instance.normalized + normalized = re.sub(r"[^a-z ]", "", normalized) + return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized) + + +def filter_plain_1(instance: Instance) -> bool: + ok = True + return ok + + +def process_plain_1(instance: Instance) -> Instance: + un_normalized = instance.un_normalized + normalized = instance.normalized + return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized) + + +def filter_punct_1(instance: Instance) -> bool: + ok = True + return ok + + +def process_punct_1(instance: Instance) -> Instance: + un_normalized = instance.un_normalized + normalized = instance.normalized + return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized) + + +def filter_date_1(instance: Instance) -> bool: + ok = True + return ok + + +def process_date_1(instance: Instance) -> Instance: + un_normalized = instance.un_normalized + un_normalized = re.sub(r",", "", un_normalized) + normalized = instance.normalized + normalized = re.sub(r"[^a-z ]", "", normalized) + return 
Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized) + + +def filter_letters_1(instance: Instance) -> bool: + ok = True + return ok + + +def process_letters_1(instance: Instance) -> Instance: + un_normalized = instance.un_normalized + normalized = instance.normalized + normalized = re.sub(r"[^a-z ]", "", normalized) + return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized) + + +def filter_verbatim_1(instance: Instance) -> bool: + ok = True + return ok + + +def process_verbatim_1(instance: Instance) -> Instance: + un_normalized = instance.un_normalized + normalized = instance.normalized + return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized) + + +def filter_digit_1(instance: Instance) -> bool: + ok = re.search(r"[0-9]", instance.un_normalized) + return ok + + +def process_digit_1(instance: Instance) -> Instance: + un_normalized = instance.un_normalized + normalized = instance.normalized + normalized = re.sub(r"[^a-z ]", "", normalized) + return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized) + + +def filter_telephone_1(instance: Instance) -> bool: + ok = re.search(r"[0-9]", instance.un_normalized) + return ok + + +def process_telephone_1(instance: Instance) -> Instance: + un_normalized = instance.un_normalized + normalized = instance.normalized + normalized = re.sub(r"[^a-z ]", "", normalized) + return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized) + + +def filter_electronic_1(instance: Instance) -> bool: + ok = re.search(r"[0-9]", instance.un_normalized) + return ok + + +def process_electronic_1(instance: Instance) -> Instance: + un_normalized = instance.un_normalized + normalized = instance.normalized + normalized = re.sub(r"[^a-z ]", "", normalized) + return Instance(token_type=instance.token_type, un_normalized=un_normalized, 
normalized=normalized) + + +def filter_fraction_1(instance: Instance) -> bool: + ok = re.search(r"[0-9]", instance.un_normalized) + return ok + + +def process_fraction_1(instance: Instance) -> Instance: + un_normalized = instance.un_normalized + normalized = instance.normalized + normalized = re.sub(r"[^a-z ]", "", normalized) + return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized) + + +def filter_address_1(instance: Instance) -> bool: + ok = True + return ok + + +def process_address_1(instance: Instance) -> Instance: + un_normalized = instance.un_normalized + normalized = instance.normalized + normalized = re.sub(r"[^a-z ]", "", normalized) + return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized) + + +filters = [] +filters.append(Filter(class_type="CARDINAL", process_func=process_cardinal_1, filter_func=filter_cardinal_1)) +filters.append(Filter(class_type="ORDINAL", process_func=process_ordinal_1, filter_func=filter_ordinal_1)) +filters.append(Filter(class_type="DECIMAL", process_func=process_decimal_1, filter_func=filter_decimal_1)) +filters.append(Filter(class_type="MEASURE", process_func=process_measure_1, filter_func=filter_measure_1)) +filters.append(Filter(class_type="MONEY", process_func=process_money_1, filter_func=filter_money_1)) +filters.append(Filter(class_type="TIME", process_func=process_time_1, filter_func=filter_time_1)) + +filters.append(Filter(class_type="DATE", process_func=process_date_1, filter_func=filter_date_1)) +filters.append(Filter(class_type="PLAIN", process_func=process_plain_1, filter_func=filter_plain_1)) +filters.append(Filter(class_type="PUNCT", process_func=process_punct_1, filter_func=filter_punct_1)) +filters.append(Filter(class_type="LETTERS", process_func=process_letters_1, filter_func=filter_letters_1)) +filters.append(Filter(class_type="VERBATIM", process_func=process_verbatim_1, filter_func=filter_verbatim_1)) 
+filters.append(Filter(class_type="DIGIT", process_func=process_digit_1, filter_func=filter_digit_1)) +filters.append(Filter(class_type="TELEPHONE", process_func=process_telephone_1, filter_func=filter_telephone_1)) +filters.append(Filter(class_type="ELECTRONIC", process_func=process_electronic_1, filter_func=filter_electronic_1)) +filters.append(Filter(class_type="FRACTION", process_func=process_fraction_1, filter_func=filter_fraction_1)) +filters.append(Filter(class_type="ADDRESS", process_func=process_address_1, filter_func=filter_address_1)) +filters.append(Filter(class_type=EOS_TYPE, process_func=lambda x: x, filter_func=lambda x: True)) + + +def filter_loaded_data(data: List[Instance], verbose: bool = False) -> List[Instance]: + """ + Filters list of instances + + Args: + data: list of instances + + Returns: filtered and transformed list of instances + """ + updates_instances = [] + for instance in data: + updated_instance = False + for fil in filters: + if fil.class_type == instance.token_type and fil.filter(instance): + instance = fil.process(instance) + updated_instance = True + if updated_instance: + if verbose: + print(instance) + updates_instances.append(instance) + return updates_instances + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument("--input", help="input file path", type=str, default='./en_with_types/output-00001-of-00100') + parser.add_argument("--verbose", help="print filtered instances", action='store_true') + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + file_path = args.input + + print("Loading training data: " + file_path) + instance_list = load_files([file_path]) # List of instances + filtered_instance_list = filter_loaded_data(instance_list, args.verbose) + training_data_to_sentences(filtered_instance_list) diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/graph_utils.py b/utils/speechio/nemo_text_processing/text_normalization/en/graph_utils.py new file mode 
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright 2015 and onwards Google, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import string
from pathlib import Path
from typing import Dict

import pynini
from nemo_text_processing.text_normalization.en.utils import get_abs_path
from pynini import Far
from pynini.examples import plurals
from pynini.export import export
from pynini.lib import byte, pynutil, utf8

# Shared character classes and small helper transducers used by all English
# normalization grammars in this package.
NEMO_CHAR = utf8.VALID_UTF8_CHAR

NEMO_DIGIT = byte.DIGIT
NEMO_LOWER = pynini.union(*string.ascii_lowercase).optimize()
NEMO_UPPER = pynini.union(*string.ascii_uppercase).optimize()
NEMO_ALPHA = pynini.union(NEMO_LOWER, NEMO_UPPER).optimize()
NEMO_ALNUM = pynini.union(NEMO_DIGIT, NEMO_ALPHA).optimize()
NEMO_HEX = pynini.union(*string.hexdigits).optimize()
NEMO_NON_BREAKING_SPACE = u"\u00A0"
NEMO_SPACE = " "
NEMO_WHITE_SPACE = pynini.union(" ", "\t", "\n", "\r", u"\u00A0").optimize()
NEMO_NOT_SPACE = pynini.difference(NEMO_CHAR, NEMO_WHITE_SPACE).optimize()
NEMO_NOT_QUOTE = pynini.difference(NEMO_CHAR, r'"').optimize()

NEMO_PUNCT = pynini.union(*map(pynini.escape, string.punctuation)).optimize()
NEMO_GRAPH = pynini.union(NEMO_ALNUM, NEMO_PUNCT).optimize()

# Sigma-star: accepts any UTF-8 string.
NEMO_SIGMA = pynini.closure(NEMO_CHAR)

delete_space = pynutil.delete(pynini.closure(NEMO_WHITE_SPACE))
delete_zero_or_one_space = pynutil.delete(pynini.closure(NEMO_WHITE_SPACE, 0, 1))
insert_space = pynutil.insert(" ")
delete_extra_space = pynini.cross(pynini.closure(NEMO_WHITE_SPACE, 1), " ")
# Strips serialized token-ordering attributes emitted by the taggers.
delete_preserve_order = pynini.closure(
    pynutil.delete(" preserve_order: true")
    | (pynutil.delete(" field_order: \"") + NEMO_NOT_QUOTE + pynutil.delete("\""))
)

# English pluralization: irregular forms from the suppletive table take
# priority, then -y/-ies, then -es after sibilants, then plain -s.
suppletive = pynini.string_file(get_abs_path("data/suppletive.tsv"))
# _v = pynini.union("a", "e", "i", "o", "u")
_c = pynini.union(
    "b", "c", "d", "f", "g", "h", "j", "k", "l", "m", "n", "p", "q", "r", "s", "t", "v", "w", "x", "y", "z"
)
_ies = NEMO_SIGMA + _c + pynini.cross("y", "ies")
_es = NEMO_SIGMA + pynini.union("s", "sh", "ch", "x", "z") + pynutil.insert("es")
_s = NEMO_SIGMA + pynutil.insert("s")

graph_plural = plurals._priority_union(
    suppletive, plurals._priority_union(_ies, plurals._priority_union(_es, _s, NEMO_SIGMA), NEMO_SIGMA), NEMO_SIGMA
).optimize()

SINGULAR_TO_PLURAL = graph_plural
PLURAL_TO_SINGULAR = pynini.invert(graph_plural)
TO_LOWER = pynini.union(*[pynini.cross(x, y) for x, y in zip(string.ascii_uppercase, string.ascii_lowercase)])
TO_UPPER = pynini.invert(TO_LOWER)
MIN_NEG_WEIGHT = -0.0001
MIN_POS_WEIGHT = 0.0001


def generator_main(file_name: str, graphs: Dict[str, 'pynini.FstLike']):
    """
    Exports graph as OpenFst finite state archive (FAR) file with given file name and rule name.

    Args:
        file_name: exported file name
        graphs: Mapping of a rule name and Pynini WFST graph to be exported
    """
    exporter = export.Exporter(file_name)
    for rule, graph in graphs.items():
        exporter[rule] = graph.optimize()
    exporter.close()
    print(f'Created {file_name}')


def get_plurals(fst):
    """
    Given singular returns plurals

    Args:
        fst: Fst

    Returns plurals to given singular forms
    """
    return SINGULAR_TO_PLURAL @ fst


def get_singulars(fst):
    """
    Given plural returns singulars

    Args:
        fst: Fst

    Returns singulars to given plural forms
    """
    return PLURAL_TO_SINGULAR @ fst


def convert_space(fst) -> 'pynini.FstLike':
    """
    Converts space to nonbreaking space.
    Used only in tagger grammars for transducing token values within quotes, e.g. name: "hello kitty"
    This is making transducer significantly slower, so only use when there could be potential spaces within quotes, otherwise leave it.

    Args:
        fst: input fst

    Returns output fst where breaking spaces are converted to non breaking spaces
    """
    return fst @ pynini.cdrewrite(pynini.cross(NEMO_SPACE, NEMO_NON_BREAKING_SPACE), "", "", NEMO_SIGMA)


class GraphFst:
    """
    Base class for all grammar fsts.

    Args:
        name: name of grammar class
        kind: either 'classify' or 'verbalize'
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, name: str, kind: str, deterministic: bool = True):
        self.name = name
        # FIX: was `self.kind = str`, which stored the builtin `str` type
        # instead of the constructor argument, so the attribute never
        # recorded whether this grammar is a classifier or a verbalizer.
        self.kind = kind
        self._fst = None
        self.deterministic = deterministic

        # Pre-compiled grammars are cached as FAR files next to this module.
        self.far_path = Path(os.path.dirname(__file__) + '/grammars/' + kind + '/' + name + '.far')
        if self.far_exist():
            self._fst = Far(self.far_path, mode="r", arc_type="standard", far_type="default").get_fst()

    def far_exist(self) -> bool:
        """
        Returns true if FAR can be loaded
        """
        return self.far_path.exists()

    @property
    def fst(self) -> 'pynini.FstLike':
        return self._fst

    @fst.setter
    def fst(self, fst):
        self._fst = fst

    def add_tokens(self, fst) -> 'pynini.FstLike':
        """
        Wraps class name around to given fst

        Args:
            fst: input fst

        Returns:
            Fst: fst
        """
        return pynutil.insert(f"{self.name} {{ ") + fst + pynutil.insert(" }")

    def delete_tokens(self, fst) -> 'pynini.FstLike':
        """
        Deletes class name wrap around output of given fst

        Args:
            fst: input fst

        Returns:
            Fst: fst
        """
        res = (
            pynutil.delete(f"{self.name}")
            + delete_space
            + pynutil.delete("{")
            + delete_space
            + fst
            + delete_space
            + pynutil.delete("}")
        )
        # Restore ordinary spaces inside quoted values (see convert_space).
        return res @ pynini.cdrewrite(pynini.cross(u"\u00A0", " "), "", "", NEMO_SIGMA)
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_UPPER, GraphFst, insert_space
from pynini.lib import pynutil


class AbbreviationFst(GraphFst):
    """
    Finite state transducer for classifying abbreviations (sequences of
    uppercase letters, optionally dot-separated),
    e.g. "ABC" -> tokens { abbreviation { value: "A B C" } }

    Args:
        whitelist: whitelist FST; abbreviations already covered by the
            whitelist are excluded from this grammar
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, whitelist: 'pynini.FstLike', deterministic: bool = True):
        super().__init__(name="abbreviation", kind="classify", deterministic=deterministic)

        dot = pynini.accep(".")
        # A.B.C. -> A. B. C.
        graph = NEMO_UPPER + dot + pynini.closure(insert_space + NEMO_UPPER + dot, 1)
        # A.B.C. -> A.B.C.
        graph |= NEMO_UPPER + dot + pynini.closure(NEMO_UPPER + dot, 1)
        # ABC -> A B C
        graph |= NEMO_UPPER + pynini.closure(insert_space + NEMO_UPPER, 1)

        # exclude words that are included in the whitelist
        graph = pynini.compose(
            pynini.difference(pynini.project(graph, "input"), pynini.project(whitelist.graph, "input")), graph
        )

        graph = pynutil.insert("value: \"") + graph.optimize() + pynutil.insert("\"")
        graph = self.add_tokens(graph)
        self.fst = graph.optimize()
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    NEMO_DIGIT,
    NEMO_NOT_QUOTE,
    NEMO_SIGMA,
    GraphFst,
    insert_space,
)
from nemo_text_processing.text_normalization.en.taggers.date import get_four_digit_year_graph
from nemo_text_processing.text_normalization.en.utils import get_abs_path
from pynini.examples import plurals
from pynini.lib import pynutil


class CardinalFst(GraphFst):
    """
    Finite state transducer for classifying cardinals, e.g.
        -23 -> cardinal { negative: "true" integer: "twenty three" } }

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
        lm: if True, disables the "hundred and" variants (see add_optional_and)
    """

    def __init__(self, deterministic: bool = True, lm: bool = False):
        super().__init__(name="cardinal", kind="classify", deterministic=deterministic)

        self.lm = lm
        self.deterministic = deterministic
        # TODO replace to have "oh" as a default for "0"
        # Pre-compiled number-name grammar shipped as a FAR file.
        graph = pynini.Far(get_abs_path("data/number/cardinal_number_name.far")).get_fst()
        # Accepts 1-3 digit numbers whose first digit is non-zero (or any 2-3 digit number).
        self.graph_hundred_component_at_least_one_none_zero_digit = (
            pynini.closure(NEMO_DIGIT, 2, 3) | pynini.difference(NEMO_DIGIT, pynini.accep("0"))
        ) @ graph

        graph_digit = pynini.string_file(get_abs_path("data/number/digit.tsv"))
        graph_zero = pynini.string_file(get_abs_path("data/number/zero.tsv"))

        # Digit-by-digit reading, e.g. "123" -> "one two three".
        single_digits_graph = pynini.invert(graph_digit | graph_zero)
        self.single_digits_graph = single_digits_graph + pynini.closure(insert_space + single_digits_graph)

        if not deterministic:
            # for a single token allow only the same normalization
            # "007" -> {"oh oh seven", "zero zero seven"} not {"oh zero seven"}
            single_digits_graph_zero = pynini.invert(graph_digit | graph_zero)
            single_digits_graph_oh = pynini.invert(graph_digit) | pynini.cross("0", "oh")

            self.single_digits_graph = single_digits_graph_zero + pynini.closure(
                insert_space + single_digits_graph_zero
            )
            self.single_digits_graph |= single_digits_graph_oh + pynini.closure(insert_space + single_digits_graph_oh)

            # Digit-by-digit reading of comma-grouped numbers, e.g. "1,234".
            single_digits_graph_with_commas = pynini.closure(
                self.single_digits_graph + insert_space, 1, 3
            ) + pynini.closure(
                pynutil.delete(",")
                + single_digits_graph
                + insert_space
                + single_digits_graph
                + insert_space
                + single_digits_graph,
                1,
            )

        optional_minus_graph = pynini.closure(pynutil.insert("negative: ") + pynini.cross("-", "\"true\" "), 0, 1)

        # Accept optional thousands separators: 1-3 leading digits followed by
        # comma-separated or plain 3-digit groups.
        graph = (
            pynini.closure(NEMO_DIGIT, 1, 3)
            + (pynini.closure(pynutil.delete(",") + NEMO_DIGIT ** 3) | pynini.closure(NEMO_DIGIT ** 3))
        ) @ graph

        self.graph = graph
        self.graph_with_and = self.add_optional_and(graph)

        if deterministic:
            # Numbers with 5+ digits are always read digit by digit.
            long_numbers = pynini.compose(NEMO_DIGIT ** (5, ...), self.single_digits_graph).optimize()
            final_graph = plurals._priority_union(long_numbers, self.graph_with_and, NEMO_SIGMA).optimize()
            cardinal_with_leading_zeros = pynini.compose(
                pynini.accep("0") + pynini.closure(NEMO_DIGIT), self.single_digits_graph
            )
            final_graph |= cardinal_with_leading_zeros
        else:
            leading_zeros = pynini.compose(pynini.closure(pynini.accep("0"), 1), self.single_digits_graph)
            cardinal_with_leading_zeros = (
                leading_zeros + pynutil.insert(" ") + pynini.compose(pynini.closure(NEMO_DIGIT), self.graph_with_and)
            )

            # add small weight to non-default graphs to make sure the deterministic option is listed first
            final_graph = (
                self.graph_with_and
                | pynutil.add_weight(self.single_digits_graph, 0.0001)
                | get_four_digit_year_graph()  # allows e.g. 4567 be pronounced as forty five sixty seven
                | pynutil.add_weight(single_digits_graph_with_commas, 0.0001)
                | cardinal_with_leading_zeros
            )

        final_graph = optional_minus_graph + pynutil.insert("integer: \"") + final_graph + pynutil.insert("\"")
        final_graph = self.add_tokens(final_graph)
        self.fst = final_graph.optimize()

    def add_optional_and(self, graph):
        """
        Extends a verbalized-cardinal graph with optional British-style "and"
        insertions ("one hundred and twenty three") and the "hundred"-dropping
        variant for 3-digit numbers. Disabled when self.lm is True.
        """
        graph_with_and = graph

        if not self.lm:
            # Penalize the plain form slightly so the "and" variants can win on demand.
            graph_with_and = pynutil.add_weight(graph, 0.00001)
            not_quote = pynini.closure(NEMO_NOT_QUOTE)
            # "hundred and X" only when no larger unit follows.
            no_thousand_million = pynini.difference(
                not_quote, not_quote + pynini.union("thousand", "million") + not_quote
            ).optimize()
            integer = (
                not_quote + pynutil.add_weight(pynini.cross("hundred ", "hundred and ") + no_thousand_million, -0.0001)
            ).optimize()

            no_hundred = pynini.difference(NEMO_SIGMA, not_quote + pynini.accep("hundred") + not_quote).optimize()
            integer |= (
                not_quote + pynutil.add_weight(pynini.cross("thousand ", "thousand and ") + no_hundred, -0.0001)
            ).optimize()

            # e.g. "123" -> "one twenty three" (hundred dropped).
            optional_hundred = pynini.compose((NEMO_DIGIT - "0") ** 3, graph).optimize()
            optional_hundred = pynini.compose(optional_hundred, NEMO_SIGMA + pynini.cross(" hundred", "") + NEMO_SIGMA)
            graph_with_and |= pynini.compose(graph, integer).optimize()
            graph_with_and |= optional_hundred
        return graph_with_and
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    NEMO_CHAR,
    NEMO_DIGIT,
    NEMO_LOWER,
    NEMO_SIGMA,
    NEMO_NOT_QUOTE,
    TO_LOWER,
    GraphFst,
    delete_extra_space,
    delete_space,
    insert_space,
)
from nemo_text_processing.text_normalization.en.utils import (
    augment_labels_with_punct_at_end,
    get_abs_path,
    load_labels,
)
from pynini.examples import plurals
from pynini.lib import pynutil

# Module-level building blocks shared by the year/date helpers below.
# NOTE(review): assumes the TSV files map written digits to number names
# (e.g. "12" -> "twelve") — confirm against the data files.
graph_teen = pynini.invert(pynini.string_file(get_abs_path("data/number/teen.tsv"))).optimize()
graph_digit = pynini.invert(pynini.string_file(get_abs_path("data/number/digit.tsv"))).optimize()
ties_graph = pynini.invert(pynini.string_file(get_abs_path("data/number/ty.tsv"))).optimize()
year_suffix = load_labels(get_abs_path("data/date/year_suffix.tsv"))
year_suffix.extend(augment_labels_with_punct_at_end(year_suffix))
year_suffix = pynini.string_map(year_suffix).optimize()


def get_ties_graph(deterministic: bool = True):
    """
    Returns two digit transducer, e.g.
    03 -> o three
    12 -> twelve
    20 -> twenty
    """
    graph = graph_teen | ties_graph + pynutil.delete("0") | ties_graph + insert_space + graph_digit

    if deterministic:
        graph = graph | pynini.cross("0", "o") + insert_space + graph_digit
    else:
        graph = graph | (pynini.cross("0", "o") | pynini.cross("0", "zero")) + insert_space + graph_digit

    return graph.optimize()


def get_four_digit_year_graph(deterministic: bool = True):
    """
    Returns a four digit transducer which is combination of ties/teen or digits
    (using hundred instead of thousand format), e.g.
    1219 -> twelve nineteen
    3900 -> thirty nine hundred
    """
    graph_ties = get_ties_graph(deterministic)

    # Decade forms, e.g. "1970s" -> "nineteen seventies".
    graph_with_s = (
        (graph_ties + insert_space + graph_ties)
        | (graph_teen + insert_space + (ties_graph | pynini.cross("1", "ten")))
    ) + pynutil.delete("0s")

    graph_with_s |= (graph_teen | graph_ties) + insert_space + pynini.cross("00", "hundred") + pynutil.delete("s")
    graph_with_s = graph_with_s @ pynini.cdrewrite(
        pynini.cross("y", "ies") | pynutil.insert("s"), "", "[EOS]", NEMO_SIGMA
    )

    graph = graph_ties + insert_space + graph_ties
    graph |= (graph_teen | graph_ties) + insert_space + pynini.cross("00", "hundred")

    # "X00Y" / "X000s" read with "thousand", e.g. 2005 -> "two thousand five".
    thousand_graph = (
        graph_digit
        + insert_space
        + pynini.cross("00", "thousand")
        + (pynutil.delete("0") | insert_space + graph_digit)
    )
    thousand_graph |= (
        graph_digit
        + insert_space
        + pynini.cross("000", "thousand")
        + pynini.closure(pynutil.delete(" "), 0, 1)
        + pynini.accep("s")
    )

    graph |= graph_with_s
    if deterministic:
        graph = plurals._priority_union(thousand_graph, graph, NEMO_SIGMA)
    else:
        graph |= thousand_graph

    return graph.optimize()


def _get_two_digit_year_with_s_graph():
    # to handle '70s -> seventies
    graph = (
        pynini.closure(pynutil.delete("'"), 0, 1)
        + pynini.compose(
            ties_graph + pynutil.delete("0s"), pynini.cdrewrite(pynini.cross("y", "ies"), "", "[EOS]", NEMO_SIGMA)
        )
    ).optimize()
    return graph


def _get_year_graph(cardinal_graph, deterministic: bool = True):
    """
    Transducer for year, only from 1000 - 2999 e.g.
    1290 -> twelve ninety
    2000 - 2009 will be verbalized as two thousand.

    Transducer for 3 digit year, e.g. 123-> one twenty three

    Transducer for year with suffix
    123 A.D., 4200 B.C
    """
    graph = get_four_digit_year_graph(deterministic)
    # Restrict to years starting with 1 or 2, optionally with a plural "s".
    graph = (pynini.union("1", "2") + (NEMO_DIGIT ** 3) + pynini.closure(pynini.cross(" s", "s") | "s", 0, 1)) @ graph

    graph |= _get_two_digit_year_with_s_graph()

    three_digit_year = (NEMO_DIGIT @ cardinal_graph) + insert_space + (NEMO_DIGIT ** 2) @ cardinal_graph
    year_with_suffix = (
        (get_four_digit_year_graph(deterministic=True) | three_digit_year) + delete_space + insert_space + year_suffix
    )
    graph |= year_with_suffix
    return graph.optimize()


def _get_two_digit_year(cardinal_graph, single_digits_graph):
    # Two-digit year: prefer cardinal reading, fall back to digit-by-digit.
    wo_digit_year = NEMO_DIGIT ** (2) @ plurals._priority_union(cardinal_graph, single_digits_graph, NEMO_SIGMA)
    return wo_digit_year


class DateFst(GraphFst):
    """
    Finite state transducer for classifying date, e.g.
        jan. 5, 2012 -> date { month: "january" day: "five" year: "twenty twelve" preserve_order: true }
        jan. 5 -> date { month: "january" day: "five" preserve_order: true }
        5 january 2012 -> date { day: "five" month: "january" year: "twenty twelve" preserve_order: true }
        2012-01-05 -> date { year: "twenty twelve" month: "january" day: "five" }
        2012.01.05 -> date { year: "twenty twelve" month: "january" day: "five" }
        2012/01/05 -> date { year: "twenty twelve" month: "january" day: "five" }
        2012 -> date { year: "twenty twelve" }

    Args:
        cardinal: CardinalFst
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, cardinal: GraphFst, deterministic: bool, lm: bool = False):
        super().__init__(name="date", kind="classify", deterministic=deterministic)

        # january
        month_graph = pynini.string_file(get_abs_path("data/date/month_name.tsv")).optimize()
        # January, JANUARY
        month_graph |= pynini.compose(TO_LOWER + pynini.closure(NEMO_CHAR), month_graph) | pynini.compose(
            TO_LOWER ** (2, ...), month_graph
        )

        # jan
        month_abbr_graph = pynini.string_file(get_abs_path("data/date/month_abbr.tsv")).optimize()
        # jan, Jan, JAN
        month_abbr_graph = (
            month_abbr_graph
            | pynini.compose(TO_LOWER + pynini.closure(NEMO_LOWER, 1), month_abbr_graph).optimize()
            | pynini.compose(TO_LOWER ** (2, ...), month_abbr_graph).optimize()
        ) + pynini.closure(pynutil.delete("."), 0, 1)
        month_graph |= month_abbr_graph.optimize()

        month_numbers_labels = pynini.string_file(get_abs_path("data/date/month_number.tsv")).optimize()
        cardinal_graph = cardinal.graph_hundred_component_at_least_one_none_zero_digit

        year_graph = _get_year_graph(cardinal_graph=cardinal_graph, deterministic=deterministic)

        # three_digit_year = (NEMO_DIGIT @ cardinal_graph) + insert_space + (NEMO_DIGIT ** 2) @ cardinal_graph
        # year_graph |= three_digit_year

        month_graph = pynutil.insert("month: \"") + month_graph + pynutil.insert("\"")
        month_numbers_graph = pynutil.insert("month: \"") + month_numbers_labels + pynutil.insert("\"")

        # Ordinal day endings to strip: "5th", "1st", "22nd", "3rd" (any case).
        endings = ["rd", "th", "st", "nd"]
        endings += [x.upper() for x in endings]
        endings = pynini.union(*endings)

        day_graph = (
            pynutil.insert("day: \"")
            + pynini.closure(pynutil.delete("the "), 0, 1)
            + (
                ((pynini.union("1", "2") + NEMO_DIGIT) | NEMO_DIGIT | (pynini.accep("3") + pynini.union("0", "1")))
                + pynini.closure(pynutil.delete(endings), 0, 1)
            )
            @ cardinal_graph
            + pynutil.insert("\"")
        )

        two_digit_year = _get_two_digit_year(
            cardinal_graph=cardinal_graph, single_digits_graph=cardinal.single_digits_graph
        )
        two_digit_year = pynutil.insert("year: \"") + two_digit_year + pynutil.insert("\"")

        # if lm:
        #     two_digit_year = pynini.compose(pynini.difference(NEMO_DIGIT, "0") + NEMO_DIGIT ** (3), two_digit_year)
        #     year_graph = pynini.compose(pynini.difference(NEMO_DIGIT, "0") + NEMO_DIGIT ** (2), year_graph)
        #     year_graph |= pynini.compose(pynini.difference(NEMO_DIGIT, "0") + NEMO_DIGIT ** (4, ...), year_graph)

        # Year as a trailing field, with or without a preceding comma.
        graph_year = pynutil.insert(" year: \"") + pynutil.delete(" ") + year_graph + pynutil.insert("\"")
        graph_year |= (
            pynutil.insert(" year: \"")
            + pynini.accep(",")
            + pynini.closure(pynini.accep(" "), 0, 1)
            + year_graph
            + pynutil.insert("\"")
        )
        optional_graph_year = pynini.closure(graph_year, 0, 1)

        year_graph = pynutil.insert("year: \"") + year_graph + pynutil.insert("\"")

        # month-day-year orderings ("jan. 5, 2012", "01/05/2012", ...).
        graph_mdy = month_graph + (
            (delete_extra_space + day_graph)
            | (pynini.accep(" ") + day_graph)
            | graph_year
            | (delete_extra_space + day_graph + graph_year)
        )

        graph_mdy |= (
            month_graph
            + pynini.cross("-", " ")
            + day_graph
            + pynini.closure(((pynini.cross("-", " ") + NEMO_SIGMA) @ graph_year), 0, 1)
        )

        for x in ["-", "/", "."]:
            delete_sep = pynutil.delete(x)
            graph_mdy |= (
                month_numbers_graph
                + delete_sep
                + insert_space
                + pynini.closure(pynutil.delete("0"), 0, 1)
                + day_graph
                + delete_sep
                + insert_space
                + (year_graph | two_digit_year)
            )

        # day-month-year orderings ("5 january 2012", "05/01/2012", ...).
        graph_dmy = day_graph + delete_extra_space + month_graph + optional_graph_year
        # Exclude two-digit days that could also be a month number (ambiguity guard).
        day_ex_month = (NEMO_DIGIT ** 2 - pynini.project(month_numbers_graph, "input")) @ day_graph
        for x in ["-", "/", "."]:
            delete_sep = pynutil.delete(x)
            graph_dmy |= (
                day_ex_month
                + delete_sep
                + insert_space
                + month_numbers_graph
                + delete_sep
                + insert_space
                + (year_graph | two_digit_year)
            )

        # year-month-day (ISO-like) orderings ("2012-01-05", ...).
        graph_ymd = pynini.accep("")
        for x in ["-", "/", "."]:
            delete_sep = pynutil.delete(x)
            graph_ymd |= (
                (year_graph | two_digit_year)
                + delete_sep
                + insert_space
                + month_numbers_graph
                + delete_sep
                + insert_space
                + pynini.closure(pynutil.delete("0"), 0, 1)
                + day_graph
            )

        final_graph = graph_mdy | graph_dmy

        if not deterministic or lm:
            final_graph += pynini.closure(pynutil.insert(" preserve_order: true"), 0, 1)
            # month-day without a year, e.g. "01/05".
            m_sep_d = (
                month_numbers_graph
                + pynutil.delete(pynini.union("-", "/"))
                + insert_space
                + pynini.closure(pynutil.delete("0"), 0, 1)
                + day_graph
            )
            final_graph |= m_sep_d
        else:
            final_graph += pynutil.insert(" preserve_order: true")

        final_graph |= graph_ymd | year_graph

        if not deterministic or lm:
            # Build field-reordering transducers so a date tagged in one order
            # can be verbalized in another (e.g. YMD input read as MDY).
            ymd_to_mdy_graph = None
            ymd_to_dmy_graph = None
            mdy_to_dmy_graph = None
            md_to_dm_graph = None

            for month in [x[0] for x in load_labels(get_abs_path("data/date/month_name.tsv"))]:
                for day in [x[0] for x in load_labels(get_abs_path("data/date/day.tsv"))]:
                    ymd_to_mdy_curr = (
                        pynutil.insert("month: \"" + month + "\" day: \"" + day + "\" ")
                        + pynini.accep('year:')
                        + NEMO_SIGMA
                        + pynutil.delete(" month: \"" + month + "\" day: \"" + day + "\"")
                    )

                    # YY-MM-DD -> MM-DD-YY
                    ymd_to_mdy_curr = pynini.compose(graph_ymd, ymd_to_mdy_curr)
                    ymd_to_mdy_graph = (
                        ymd_to_mdy_curr
                        if ymd_to_mdy_graph is None
                        else pynini.union(ymd_to_mdy_curr, ymd_to_mdy_graph)
                    )

                    ymd_to_dmy_curr = (
                        pynutil.insert("day: \"" + day + "\" month: \"" + month + "\" ")
                        + pynini.accep('year:')
                        + NEMO_SIGMA
                        + pynutil.delete(" month: \"" + month + "\" day: \"" + day + "\"")
                    )

                    # YY-MM-DD -> MM-DD-YY
                    ymd_to_dmy_curr = pynini.compose(graph_ymd, ymd_to_dmy_curr).optimize()
                    ymd_to_dmy_graph = (
                        ymd_to_dmy_curr
                        if ymd_to_dmy_graph is None
                        else pynini.union(ymd_to_dmy_curr, ymd_to_dmy_graph)
                    )

                    mdy_to_dmy_curr = (
                        pynutil.insert("day: \"" + day + "\" month: \"" + month + "\" ")
                        + pynutil.delete("month: \"" + month + "\" day: \"" + day + "\" ")
                        + pynini.accep('year:')
                        + NEMO_SIGMA
                    ).optimize()
                    # MM-DD-YY -> verbalize as MM-DD-YY (February fourth 1991) or DD-MM-YY (the fourth of February 1991)
                    mdy_to_dmy_curr = pynini.compose(graph_mdy, mdy_to_dmy_curr).optimize()
                    mdy_to_dmy_graph = (
                        mdy_to_dmy_curr
                        if mdy_to_dmy_graph is None
                        else pynini.union(mdy_to_dmy_curr, mdy_to_dmy_graph).optimize()
                    ).optimize()

                    md_to_dm_curr = pynutil.insert("day: \"" + day + "\" month: \"" + month + "\"") + pynutil.delete(
                        "month: \"" + month + "\" day: \"" + day + "\""
                    )
                    md_to_dm_curr = pynini.compose(m_sep_d, md_to_dm_curr).optimize()

                    md_to_dm_graph = (
                        md_to_dm_curr
                        if md_to_dm_graph is None
                        else pynini.union(md_to_dm_curr, md_to_dm_graph).optimize()
                    ).optimize()

            final_graph |= mdy_to_dmy_graph | md_to_dm_graph | ymd_to_mdy_graph | ymd_to_dmy_graph

        final_graph = self.add_tokens(final_graph)
        self.fst = final_graph.optimize()
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_SIGMA, TO_UPPER, GraphFst, get_abs_path
from pynini.lib import pynutil

# Local single-space deleter (intentionally narrower than graph_utils.delete_space,
# which consumes arbitrary whitespace runs).
delete_space = pynutil.delete(" ")
quantities = pynini.string_file(get_abs_path("data/number/thousand.tsv"))
quantities_abbr = pynini.string_file(get_abs_path("data/number/quantity_abbr.tsv"))
quantities_abbr |= TO_UPPER @ quantities_abbr


def get_quantity(
    decimal: 'pynini.FstLike', cardinal_up_to_hundred: 'pynini.FstLike', include_abbr: bool
) -> 'pynini.FstLike':
    """
    Returns FST that transforms either a cardinal or decimal followed by a quantity into a numeral,
    e.g. 1 million -> integer_part: "one" quantity: "million"
    e.g. 1.5 million -> integer_part: "one" fractional_part: "five" quantity: "million"

    Args:
        decimal: decimal FST
        cardinal_up_to_hundred: cardinal FST
        include_abbr: if True, also accept abbreviated quantities (e.g. "M", "bn")
    """
    # "thousand"/"k" is excluded from the bare-cardinal branch: plain cardinals
    # already handle thousands.
    quantity_wo_thousand = pynini.project(quantities, "input") - pynini.union("k", "K", "thousand")
    if include_abbr:
        quantity_wo_thousand |= pynini.project(quantities_abbr, "input") - pynini.union("k", "K", "thousand")
    res = (
        pynutil.insert("integer_part: \"")
        + cardinal_up_to_hundred
        + pynutil.insert("\"")
        + pynini.closure(pynutil.delete(" "), 0, 1)
        + pynutil.insert(" quantity: \"")
        + (quantity_wo_thousand @ (quantities | quantities_abbr))
        + pynutil.insert("\"")
    )
    if include_abbr:
        quantity = quantities | quantities_abbr
    else:
        quantity = quantities
    res |= (
        decimal
        + pynini.closure(pynutil.delete(" "), 0, 1)
        # NOTE(review): unlike the branch above, no leading space is inserted
        # before `quantity:` here — possibly a missing " " in the insert;
        # confirm against the decimal verbalizer before changing.
        + pynutil.insert("quantity: \"")
        + quantity
        + pynutil.insert("\"")
    )
    return res


class DecimalFst(GraphFst):
    """
    Finite state transducer for classifying decimal, e.g.
        -12.5006 billion -> decimal { negative: "true" integer_part: "12" fractional_part: "five o o six" quantity: "billion" }
        1 billion -> decimal { integer_part: "one" quantity: "billion" }

    Args:
        cardinal: CardinalFst
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, cardinal: GraphFst, deterministic: bool):
        super().__init__(name="decimal", kind="classify", deterministic=deterministic)

        cardinal_graph = cardinal.graph_with_and
        cardinal_graph_hundred_component_at_least_one_none_zero_digit = (
            cardinal.graph_hundred_component_at_least_one_none_zero_digit
        )

        # Fractional part is always read digit by digit.
        self.graph = cardinal.single_digits_graph.optimize()

        if not deterministic:
            self.graph = self.graph | cardinal_graph

        point = pynutil.delete(".")
        optional_graph_negative = pynini.closure(pynutil.insert("negative: ") + pynini.cross("-", "\"true\" "), 0, 1)

        self.graph_fractional = pynutil.insert("fractional_part: \"") + self.graph + pynutil.insert("\"")
        self.graph_integer = pynutil.insert("integer_part: \"") + cardinal_graph + pynutil.insert("\"")
        # Integer part is optional so inputs like ".5" are accepted.
        final_graph_wo_sign = (
            pynini.closure(self.graph_integer + pynutil.insert(" "), 0, 1)
            + point
            + pynutil.insert(" ")
            + self.graph_fractional
        )

        quantity_w_abbr = get_quantity(
            final_graph_wo_sign, cardinal_graph_hundred_component_at_least_one_none_zero_digit, include_abbr=True
        )
        quantity_wo_abbr = get_quantity(
            final_graph_wo_sign, cardinal_graph_hundred_component_at_least_one_none_zero_digit, include_abbr=False
        )
        self.final_graph_wo_negative_w_abbr = final_graph_wo_sign | quantity_w_abbr
        self.final_graph_wo_negative = final_graph_wo_sign | quantity_wo_abbr

        # reduce options for non_deterministic and allow either "oh" or "zero", but not combination
        if not deterministic:
            no_oh_zero = pynini.difference(
                NEMO_SIGMA,
                (NEMO_SIGMA + "oh" + NEMO_SIGMA + "zero" + NEMO_SIGMA)
                | (NEMO_SIGMA + "zero" + NEMO_SIGMA + "oh" + NEMO_SIGMA),
            ).optimize()
            no_zero_oh = pynini.difference(
                NEMO_SIGMA, NEMO_SIGMA + pynini.accep("zero") + NEMO_SIGMA + pynini.accep("oh") + NEMO_SIGMA
            ).optimize()

            self.final_graph_wo_negative |= pynini.compose(
                self.final_graph_wo_negative,
                pynini.cdrewrite(
                    pynini.cross("integer_part: \"zero\"", "integer_part: \"oh\""), NEMO_SIGMA, NEMO_SIGMA, NEMO_SIGMA
                ),
            )
            self.final_graph_wo_negative = pynini.compose(self.final_graph_wo_negative, no_oh_zero).optimize()
            self.final_graph_wo_negative = pynini.compose(self.final_graph_wo_negative, no_zero_oh).optimize()

        final_graph = optional_graph_negative + self.final_graph_wo_negative

        final_graph = self.add_tokens(final_graph)
        self.fst = final_graph.optimize()
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    NEMO_ALPHA,
    NEMO_DIGIT,
    NEMO_SIGMA,
    GraphFst,
    get_abs_path,
    insert_space,
)
from pynini.lib import pynutil


class ElectronicFst(GraphFst):
    """
    Finite state transducer for classifying electronic: as URLs, email addresses, etc.
    e.g. cdf1@abc.edu -> tokens { electronic { username: "cdf1" domain: "abc.edu" } }

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="electronic", kind="classify", deterministic=deterministic)

        # Input-side symbol sets accepted inside usernames/domains.
        accepted_symbols = pynini.project(pynini.string_file(get_abs_path("data/electronic/symbol.tsv")), "input")
        accepted_common_domains = pynini.project(
            pynini.string_file(get_abs_path("data/electronic/domain.tsv")), "input"
        )
        all_accepted_symbols = NEMO_ALPHA + pynini.closure(NEMO_ALPHA | NEMO_DIGIT | accepted_symbols)
        graph_symbols = pynini.string_file(get_abs_path("data/electronic/symbol.tsv")).optimize()

        username = (
            pynutil.insert("username: \"") + all_accepted_symbols + pynutil.insert("\"") + pynini.cross('@', ' ')
        )
        domain_graph = all_accepted_symbols + pynini.accep('.') + all_accepted_symbols + NEMO_ALPHA
        # NOTE(review): ":" is verbalized as "semicolon" here — looks like it
        # should be "colon"; confirm against the verbalizer/test data before changing.
        protocol_symbols = pynini.closure((graph_symbols | pynini.cross(":", "semicolon")) + pynutil.insert(" "))
        protocol_start = (pynini.cross("https", "HTTPS ") | pynini.cross("http", "HTTP ")) + (
            pynini.accep("://") @ protocol_symbols
        )
        protocol_file_start = pynini.accep("file") + insert_space + (pynini.accep(":///") @ protocol_symbols)

        protocol_end = pynini.cross("www", "WWW ") + pynini.accep(".") @ protocol_symbols
        protocol = protocol_file_start | protocol_start | protocol_end | (protocol_start + protocol_end)

        # A domain must not itself start with a protocol prefix.
        domain_graph = (
            pynutil.insert("domain: \"")
            + pynini.difference(domain_graph, pynini.project(protocol, "input") + NEMO_SIGMA)
            + pynutil.insert("\"")
        )
        domain_common_graph = (
            pynutil.insert("domain: \"")
            + pynini.difference(
                all_accepted_symbols
                + accepted_common_domains
                + pynini.closure(accepted_symbols + pynini.closure(NEMO_ALPHA | NEMO_DIGIT | accepted_symbols), 0, 1),
                pynini.project(protocol, "input") + NEMO_SIGMA,
            )
            + pynutil.insert("\"")
        )

        protocol = pynutil.insert("protocol: \"") + protocol + pynutil.insert("\"")
        # email
        graph = username + domain_graph
        # abc.com, abc.com/123-sm
        graph |= domain_common_graph
        # www.abc.com/sdafsdf, or https://www.abc.com/asdfad or www.abc.abc/asdfad
        graph |= protocol + pynutil.insert(" ") + domain_graph

        final_graph = self.add_tokens(graph)

        self.fst = final_graph.optimize()
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import GraphFst, get_abs_path +from pynini.lib import pynutil + + +class FractionFst(GraphFst): + """ + Finite state transducer for classifying fraction + "23 4/5" -> + tokens { fraction { integer: "twenty three" numerator: "four" denominator: "five" } } + "23 4/5th" -> + tokens { fraction { integer: "twenty three" numerator: "four" denominator: "five" } } + + Args: + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + """ + + def __init__(self, cardinal, deterministic: bool = True): + super().__init__(name="fraction", kind="classify", deterministic=deterministic) + cardinal_graph = cardinal.graph + + integer = pynutil.insert("integer_part: \"") + cardinal_graph + pynutil.insert("\"") + numerator = ( + pynutil.insert("numerator: \"") + cardinal_graph + (pynini.cross("/", "\" ") | pynini.cross(" / ", "\" ")) + ) + + endings = ["rd", "th", "st", "nd"] + endings += [x.upper() for x in endings] + optional_end = pynini.closure(pynini.cross(pynini.union(*endings), ""), 0, 1) + + denominator = pynutil.insert("denominator: \"") + cardinal_graph + optional_end + pynutil.insert("\"") + + graph = pynini.closure(integer + pynini.accep(" "), 0, 1) + (numerator + denominator) + graph |= pynini.closure(integer + (pynini.accep(" ") | pynutil.insert(" ")), 0, 1) + pynini.compose( + pynini.string_file(get_abs_path("data/number/fraction.tsv")), (numerator + denominator) + ) + + self.graph = graph + 
final_graph = self.add_tokens(self.graph) + self.fst = final_graph.optimize() diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/taggers/measure.py b/utils/speechio/nemo_text_processing/text_normalization/en/taggers/measure.py new file mode 100644 index 0000000..3861f91 --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/taggers/measure.py @@ -0,0 +1,304 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import ( + NEMO_ALPHA, + NEMO_DIGIT, + NEMO_NON_BREAKING_SPACE, + NEMO_SIGMA, + NEMO_SPACE, + NEMO_UPPER, + SINGULAR_TO_PLURAL, + TO_LOWER, + GraphFst, + convert_space, + delete_space, + delete_zero_or_one_space, + insert_space, +) +from nemo_text_processing.text_normalization.en.taggers.ordinal import OrdinalFst as OrdinalTagger +from nemo_text_processing.text_normalization.en.taggers.whitelist import get_formats +from nemo_text_processing.text_normalization.en.utils import get_abs_path, load_labels +from nemo_text_processing.text_normalization.en.verbalizers.ordinal import OrdinalFst as OrdinalVerbalizer +from pynini.examples import plurals +from pynini.lib import pynutil + + +class MeasureFst(GraphFst): + """ + Finite state transducer for classifying measure, suppletive aware, e.g. 
+ -12kg -> measure { negative: "true" cardinal { integer: "twelve" } units: "kilograms" } + 1kg -> measure { cardinal { integer: "one" } units: "kilogram" } + .5kg -> measure { decimal { fractional_part: "five" } units: "kilograms" } + + Args: + cardinal: CardinalFst + decimal: DecimalFst + fraction: FractionFst + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + """ + + def __init__(self, cardinal: GraphFst, decimal: GraphFst, fraction: GraphFst, deterministic: bool = True): + super().__init__(name="measure", kind="classify", deterministic=deterministic) + cardinal_graph = cardinal.graph_with_and | self.get_range(cardinal.graph_with_and) + + graph_unit = pynini.string_file(get_abs_path("data/measure/unit.tsv")) + if not deterministic: + graph_unit |= pynini.string_file(get_abs_path("data/measure/unit_alternatives.tsv")) + + graph_unit |= pynini.compose( + pynini.closure(TO_LOWER, 1) + (NEMO_ALPHA | TO_LOWER) + pynini.closure(NEMO_ALPHA | TO_LOWER), graph_unit + ).optimize() + + graph_unit_plural = convert_space(graph_unit @ SINGULAR_TO_PLURAL) + graph_unit = convert_space(graph_unit) + + optional_graph_negative = pynini.closure(pynutil.insert("negative: ") + pynini.cross("-", "\"true\" "), 0, 1) + + graph_unit2 = ( + pynini.cross("/", "per") + delete_zero_or_one_space + pynutil.insert(NEMO_NON_BREAKING_SPACE) + graph_unit + ) + + optional_graph_unit2 = pynini.closure( + delete_zero_or_one_space + pynutil.insert(NEMO_NON_BREAKING_SPACE) + graph_unit2, 0, 1, + ) + + unit_plural = ( + pynutil.insert("units: \"") + + (graph_unit_plural + optional_graph_unit2 | graph_unit2) + + pynutil.insert("\"") + ) + + unit_singular = ( + pynutil.insert("units: \"") + (graph_unit + optional_graph_unit2 | graph_unit2) + pynutil.insert("\"") + ) + + subgraph_decimal = ( + pynutil.insert("decimal { ") + + optional_graph_negative + + decimal.final_graph_wo_negative + + delete_space + 
+ pynutil.insert(" } ") + + unit_plural + ) + + # support radio FM/AM + subgraph_decimal |= ( + pynutil.insert("decimal { ") + + decimal.final_graph_wo_negative + + delete_space + + pynutil.insert(" } ") + + pynutil.insert("units: \"") + + pynini.union("AM", "FM") + + pynutil.insert("\"") + ) + + subgraph_cardinal = ( + pynutil.insert("cardinal { ") + + optional_graph_negative + + pynutil.insert("integer: \"") + + ((NEMO_SIGMA - "1") @ cardinal_graph) + + delete_space + + pynutil.insert("\"") + + pynutil.insert(" } ") + + unit_plural + ) + + subgraph_cardinal |= ( + pynutil.insert("cardinal { ") + + optional_graph_negative + + pynutil.insert("integer: \"") + + pynini.cross("1", "one") + + delete_space + + pynutil.insert("\"") + + pynutil.insert(" } ") + + unit_singular + ) + + unit_graph = ( + pynutil.insert("cardinal { integer: \"-\" } units: \"") + + pynini.cross(pynini.union("/", "per"), "per") + + delete_zero_or_one_space + + pynutil.insert(NEMO_NON_BREAKING_SPACE) + + graph_unit + + pynutil.insert("\" preserve_order: true") + ) + + decimal_dash_alpha = ( + pynutil.insert("decimal { ") + + decimal.final_graph_wo_negative + + pynini.cross('-', '') + + pynutil.insert(" } units: \"") + + pynini.closure(NEMO_ALPHA, 1) + + pynutil.insert("\"") + ) + + decimal_times = ( + pynutil.insert("decimal { ") + + decimal.final_graph_wo_negative + + pynutil.insert(" } units: \"") + + pynini.cross(pynini.union('x', "X"), 'x') + + pynutil.insert("\"") + ) + + alpha_dash_decimal = ( + pynutil.insert("units: \"") + + pynini.closure(NEMO_ALPHA, 1) + + pynini.accep('-') + + pynutil.insert("\"") + + pynutil.insert(" decimal { ") + + decimal.final_graph_wo_negative + + pynutil.insert(" } preserve_order: true") + ) + + subgraph_fraction = ( + pynutil.insert("fraction { ") + fraction.graph + delete_space + pynutil.insert(" } ") + unit_plural + ) + + address = self.get_address_graph(cardinal) + address = ( + pynutil.insert("units: \"address\" cardinal { integer: \"") + + address + + 
pynutil.insert("\" } preserve_order: true") + ) + + math_operations = pynini.string_file(get_abs_path("data/measure/math_operation.tsv")) + delimiter = pynini.accep(" ") | pynutil.insert(" ") + + math = ( + (cardinal_graph | NEMO_ALPHA) + + delimiter + + math_operations + + (delimiter | NEMO_ALPHA) + + cardinal_graph + + delimiter + + pynini.cross("=", "equals") + + delimiter + + (cardinal_graph | NEMO_ALPHA) + ) + + math |= ( + (cardinal_graph | NEMO_ALPHA) + + delimiter + + pynini.cross("=", "equals") + + delimiter + + (cardinal_graph | NEMO_ALPHA) + + delimiter + + math_operations + + delimiter + + cardinal_graph + ) + + math = ( + pynutil.insert("units: \"math\" cardinal { integer: \"") + + math + + pynutil.insert("\" } preserve_order: true") + ) + final_graph = ( + subgraph_decimal + | subgraph_cardinal + | unit_graph + | decimal_dash_alpha + | decimal_times + | alpha_dash_decimal + | subgraph_fraction + | address + | math + ) + + final_graph = self.add_tokens(final_graph) + self.fst = final_graph.optimize() + + def get_range(self, cardinal: GraphFst): + """ + Returns range forms for measure tagger, e.g. 2-3, 2x3, 2*2 + + Args: + cardinal: cardinal GraphFst + """ + range_graph = cardinal + pynini.cross(pynini.union("-", " - "), " to ") + cardinal + + for x in [" x ", "x"]: + range_graph |= cardinal + pynini.cross(x, " by ") + cardinal + if not self.deterministic: + range_graph |= cardinal + pynini.cross(x, " times ") + cardinal + + for x in ["*", " * "]: + range_graph |= cardinal + pynini.cross(x, " times ") + cardinal + return range_graph.optimize() + + def get_address_graph(self, cardinal): + """ + Finite state transducer for classifying serial. 
+ The serial is a combination of digits, letters and dashes, e.g.: + 2788 San Tomas Expy, Santa Clara, CA 95051 -> + units: "address" cardinal + { integer: "two seven eight eight San Tomas Expressway Santa Clara California nine five zero five one" } + preserve_order: true + """ + ordinal_verbalizer = OrdinalVerbalizer().graph + ordinal_tagger = OrdinalTagger(cardinal=cardinal).graph + ordinal_num = pynini.compose( + pynutil.insert("integer: \"") + ordinal_tagger + pynutil.insert("\""), ordinal_verbalizer + ) + + address_num = NEMO_DIGIT ** (1, 2) @ cardinal.graph_hundred_component_at_least_one_none_zero_digit + address_num += insert_space + NEMO_DIGIT ** 2 @ ( + pynini.closure(pynini.cross("0", "zero "), 0, 1) + + cardinal.graph_hundred_component_at_least_one_none_zero_digit + ) + # to handle the rest of the numbers + address_num = pynini.compose(NEMO_DIGIT ** (3, 4), address_num) + address_num = plurals._priority_union(address_num, cardinal.graph, NEMO_SIGMA) + + direction = ( + pynini.cross("E", "East") + | pynini.cross("S", "South") + | pynini.cross("W", "West") + | pynini.cross("N", "North") + ) + pynini.closure(pynutil.delete("."), 0, 1) + + direction = pynini.closure(pynini.accep(NEMO_SPACE) + direction, 0, 1) + address_words = get_formats(get_abs_path("data/address/address_word.tsv")) + address_words = ( + pynini.accep(NEMO_SPACE) + + (pynini.closure(ordinal_num, 0, 1) | NEMO_UPPER + pynini.closure(NEMO_ALPHA, 1)) + + NEMO_SPACE + + pynini.closure(NEMO_UPPER + pynini.closure(NEMO_ALPHA) + NEMO_SPACE) + + address_words + ) + + city = pynini.closure(NEMO_ALPHA | pynini.accep(NEMO_SPACE), 1) + city = pynini.closure(pynini.accep(",") + pynini.accep(NEMO_SPACE) + city, 0, 1) + + states = load_labels(get_abs_path("data/address/state.tsv")) + + additional_options = [] + for x, y in states: + additional_options.append((x, f"{y[0]}.{y[1:]}")) + states.extend(additional_options) + state_graph = pynini.string_map(states) + state = pynini.invert(state_graph) + state = 
pynini.closure(pynini.accep(",") + pynini.accep(NEMO_SPACE) + state, 0, 1) + + zip_code = pynini.compose(NEMO_DIGIT ** 5, cardinal.single_digits_graph) + zip_code = pynini.closure(pynini.closure(pynini.accep(","), 0, 1) + pynini.accep(NEMO_SPACE) + zip_code, 0, 1,) + + address = address_num + direction + address_words + pynini.closure(city + state + zip_code, 0, 1) + + address |= address_num + direction + address_words + pynini.closure(pynini.cross(".", ""), 0, 1) + + return address diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/taggers/money.py b/utils/speechio/nemo_text_processing/text_normalization/en/taggers/money.py new file mode 100644 index 0000000..43e26bd --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/taggers/money.py @@ -0,0 +1,192 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import ( + NEMO_ALPHA, + NEMO_DIGIT, + NEMO_SIGMA, + SINGULAR_TO_PLURAL, + GraphFst, + convert_space, + insert_space, +) +from nemo_text_processing.text_normalization.en.utils import get_abs_path, load_labels +from pynini.lib import pynutil + +min_singular = pynini.string_file(get_abs_path("data/money/currency_minor_singular.tsv")) +min_plural = pynini.string_file(get_abs_path("data/money/currency_minor_plural.tsv")) +maj_singular = pynini.string_file((get_abs_path("data/money/currency_major.tsv"))) + + +class MoneyFst(GraphFst): + """ + Finite state transducer for classifying money, suppletive aware, e.g. + $12.05 -> money { integer_part: "twelve" currency_maj: "dollars" fractional_part: "five" currency_min: "cents" preserve_order: true } + $12.0500 -> money { integer_part: "twelve" currency_maj: "dollars" fractional_part: "five" currency_min: "cents" preserve_order: true } + $1 -> money { currency_maj: "dollar" integer_part: "one" } + $1.00 -> money { currency_maj: "dollar" integer_part: "one" } + $0.05 -> money { fractional_part: "five" currency_min: "cents" preserve_order: true } + $1 million -> money { currency_maj: "dollars" integer_part: "one" quantity: "million" } + $1.2 million -> money { currency_maj: "dollars" integer_part: "one" fractional_part: "two" quantity: "million" } + $1.2320 -> money { currency_maj: "dollars" integer_part: "one" fractional_part: "two three two" } + + Args: + cardinal: CardinalFst + decimal: DecimalFst + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + """ + + def __init__(self, cardinal: GraphFst, decimal: GraphFst, deterministic: bool = True): + super().__init__(name="money", kind="classify", deterministic=deterministic) + cardinal_graph = cardinal.graph_with_and + graph_decimal_final = decimal.final_graph_wo_negative_w_abbr + + maj_singular_labels 
= load_labels(get_abs_path("data/money/currency_major.tsv")) + maj_unit_plural = convert_space(maj_singular @ SINGULAR_TO_PLURAL) + maj_unit_singular = convert_space(maj_singular) + + graph_maj_singular = pynutil.insert("currency_maj: \"") + maj_unit_singular + pynutil.insert("\"") + graph_maj_plural = pynutil.insert("currency_maj: \"") + maj_unit_plural + pynutil.insert("\"") + + optional_delete_fractional_zeros = pynini.closure( + pynutil.delete(".") + pynini.closure(pynutil.delete("0"), 1), 0, 1 + ) + + graph_integer_one = pynutil.insert("integer_part: \"") + pynini.cross("1", "one") + pynutil.insert("\"") + # only for decimals where third decimal after comma is non-zero or with quantity + decimal_delete_last_zeros = ( + pynini.closure(NEMO_DIGIT | pynutil.delete(",")) + + pynini.accep(".") + + pynini.closure(NEMO_DIGIT, 2) + + (NEMO_DIGIT - "0") + + pynini.closure(pynutil.delete("0")) + ) + decimal_with_quantity = NEMO_SIGMA + NEMO_ALPHA + + graph_decimal = ( + graph_maj_plural + insert_space + (decimal_delete_last_zeros | decimal_with_quantity) @ graph_decimal_final + ) + + graph_integer = ( + pynutil.insert("integer_part: \"") + ((NEMO_SIGMA - "1") @ cardinal_graph) + pynutil.insert("\"") + ) + + graph_integer_only = graph_maj_singular + insert_space + graph_integer_one + graph_integer_only |= graph_maj_plural + insert_space + graph_integer + + final_graph = (graph_integer_only + optional_delete_fractional_zeros) | graph_decimal + + # remove trailing zeros of non zero number in the first 2 digits and fill up to 2 digits + # e.g. 
2000 -> 20, 0200->02, 01 -> 01, 10 -> 10 + # not accepted: 002, 00, 0, + two_digits_fractional_part = ( + pynini.closure(NEMO_DIGIT) + (NEMO_DIGIT - "0") + pynini.closure(pynutil.delete("0")) + ) @ ( + (pynutil.delete("0") + (NEMO_DIGIT - "0")) + | ((NEMO_DIGIT - "0") + pynutil.insert("0")) + | ((NEMO_DIGIT - "0") + NEMO_DIGIT) + ) + + graph_min_singular = pynutil.insert(" currency_min: \"") + min_singular + pynutil.insert("\"") + graph_min_plural = pynutil.insert(" currency_min: \"") + min_plural + pynutil.insert("\"") + # format ** dollars ** cent + decimal_graph_with_minor = None + integer_graph_reordered = None + decimal_default_reordered = None + for curr_symbol, _ in maj_singular_labels: + preserve_order = pynutil.insert(" preserve_order: true") + integer_plus_maj = graph_integer + insert_space + pynutil.insert(curr_symbol) @ graph_maj_plural + integer_plus_maj |= graph_integer_one + insert_space + pynutil.insert(curr_symbol) @ graph_maj_singular + + integer_plus_maj_with_comma = pynini.compose( + NEMO_DIGIT - "0" + pynini.closure(NEMO_DIGIT | pynutil.delete(",")), integer_plus_maj + ) + integer_plus_maj = pynini.compose(pynini.closure(NEMO_DIGIT) - "0", integer_plus_maj) + integer_plus_maj |= integer_plus_maj_with_comma + + graph_fractional_one = two_digits_fractional_part @ pynini.cross("1", "one") + graph_fractional_one = pynutil.insert("fractional_part: \"") + graph_fractional_one + pynutil.insert("\"") + graph_fractional = ( + two_digits_fractional_part + @ (pynini.closure(NEMO_DIGIT, 1, 2) - "1") + @ cardinal.graph_hundred_component_at_least_one_none_zero_digit + ) + graph_fractional = pynutil.insert("fractional_part: \"") + graph_fractional + pynutil.insert("\"") + + fractional_plus_min = graph_fractional + insert_space + pynutil.insert(curr_symbol) @ graph_min_plural + fractional_plus_min |= ( + graph_fractional_one + insert_space + pynutil.insert(curr_symbol) @ graph_min_singular + ) + + decimal_graph_with_minor_curr = integer_plus_maj + 
pynini.cross(".", " ") + fractional_plus_min + + if not deterministic: + decimal_graph_with_minor_curr |= pynutil.add_weight( + integer_plus_maj + + pynini.cross(".", " ") + + pynutil.insert("fractional_part: \"") + + two_digits_fractional_part @ cardinal.graph_hundred_component_at_least_one_none_zero_digit + + pynutil.insert("\""), + weight=0.0001, + ) + default_fraction_graph = (decimal_delete_last_zeros | decimal_with_quantity) @ graph_decimal_final + decimal_graph_with_minor_curr |= ( + pynini.closure(pynutil.delete("0"), 0, 1) + pynutil.delete(".") + fractional_plus_min + ) + decimal_graph_with_minor_curr = ( + pynutil.delete(curr_symbol) + decimal_graph_with_minor_curr + preserve_order + ) + + decimal_graph_with_minor = ( + decimal_graph_with_minor_curr + if decimal_graph_with_minor is None + else pynini.union(decimal_graph_with_minor, decimal_graph_with_minor_curr).optimize() + ) + + if not deterministic: + integer_graph_reordered_curr = ( + pynutil.delete(curr_symbol) + integer_plus_maj + preserve_order + ).optimize() + + integer_graph_reordered = ( + integer_graph_reordered_curr + if integer_graph_reordered is None + else pynini.union(integer_graph_reordered, integer_graph_reordered_curr).optimize() + ) + decimal_default_reordered_curr = ( + pynutil.delete(curr_symbol) + + default_fraction_graph + + insert_space + + pynutil.insert(curr_symbol) @ graph_maj_plural + ) + + decimal_default_reordered = ( + decimal_default_reordered_curr + if decimal_default_reordered is None + else pynini.union(decimal_default_reordered, decimal_default_reordered_curr) + ).optimize() + + # weight for SH + final_graph |= pynutil.add_weight(decimal_graph_with_minor, -0.0001) + + if not deterministic: + final_graph |= integer_graph_reordered | decimal_default_reordered + # to handle "$2.00" cases + final_graph |= pynini.compose( + NEMO_SIGMA + pynutil.delete(".") + pynini.closure(pynutil.delete("0"), 1), integer_graph_reordered + ) + final_graph = 
self.add_tokens(final_graph.optimize()) + self.fst = final_graph.optimize() diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/taggers/ordinal.py b/utils/speechio/nemo_text_processing/text_normalization/en/taggers/ordinal.py new file mode 100644 index 0000000..1ea56c9 --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/taggers/ordinal.py @@ -0,0 +1,61 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import NEMO_DIGIT, GraphFst +from pynini.lib import pynutil + + +class OrdinalFst(GraphFst): + """ + Finite state transducer for classifying ordinal, e.g. 
+ 13th -> ordinal { integer: "thirteen" } + + Args: + cardinal: CardinalFst + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + """ + + def __init__(self, cardinal: GraphFst, deterministic: bool = True): + super().__init__(name="ordinal", kind="classify", deterministic=deterministic) + + cardinal_graph = cardinal.graph + cardinal_format = pynini.closure(NEMO_DIGIT | pynini.accep(",")) + st_format = ( + pynini.closure(cardinal_format + (NEMO_DIGIT - "1"), 0, 1) + + pynini.accep("1") + + pynutil.delete(pynini.union("st", "ST")) + ) + nd_format = ( + pynini.closure(cardinal_format + (NEMO_DIGIT - "1"), 0, 1) + + pynini.accep("2") + + pynutil.delete(pynini.union("nd", "ND")) + ) + rd_format = ( + pynini.closure(cardinal_format + (NEMO_DIGIT - "1"), 0, 1) + + pynini.accep("3") + + pynutil.delete(pynini.union("rd", "RD")) + ) + th_format = pynini.closure( + (NEMO_DIGIT - "1" - "2" - "3") + | (cardinal_format + "1" + NEMO_DIGIT) + | (cardinal_format + (NEMO_DIGIT - "1") + (NEMO_DIGIT - "1" - "2" - "3")), + 1, + ) + pynutil.delete(pynini.union("th", "TH")) + self.graph = (st_format | nd_format | rd_format | th_format) @ cardinal_graph + final_graph = pynutil.insert("integer: \"") + self.graph + pynutil.insert("\"") + final_graph = self.add_tokens(final_graph) + self.fst = final_graph.optimize() diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/taggers/punctuation.py b/utils/speechio/nemo_text_processing/text_normalization/en/taggers/punctuation.py new file mode 100644 index 0000000..769b020 --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/taggers/punctuation.py @@ -0,0 +1,65 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from unicodedata import category + +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_SPACE, NEMO_SIGMA, GraphFst +from nemo_text_processing.text_normalization.en.utils import get_abs_path, load_labels +from pynini.examples import plurals +from pynini.lib import pynutil + + +class PunctuationFst(GraphFst): + """ + Finite state transducer for classifying punctuation + e.g. a, -> tokens { name: "a" } tokens { name: "," } + + Args: + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + + """ + + def __init__(self, deterministic: bool = True): + super().__init__(name="punctuation", kind="classify", deterministic=deterministic) + s = "!#%&\'()*+,-./:;<=>?@^_`{|}~\"" + + punct_symbols_to_exclude = ["[", "]"] + punct_unicode = [ + chr(i) + for i in range(sys.maxunicode) + if category(chr(i)).startswith("P") and chr(i) not in punct_symbols_to_exclude + ] + + whitelist_symbols = load_labels(get_abs_path("data/whitelist/symbol.tsv")) + whitelist_symbols = [x[0] for x in whitelist_symbols] + self.punct_marks = [p for p in punct_unicode + list(s) if p not in whitelist_symbols] + + punct = pynini.union(*self.punct_marks) + punct = pynini.closure(punct, 1) + + emphasis = ( + pynini.accep("<") + + ( + (pynini.closure(NEMO_NOT_SPACE - pynini.union("<", ">"), 1) + pynini.closure(pynini.accep("/"), 0, 1)) + | (pynini.accep("/") + pynini.closure(NEMO_NOT_SPACE - pynini.union("<", ">"), 1)) + ) + + pynini.accep(">") + ) + 
punct = plurals._priority_union(emphasis, punct, NEMO_SIGMA) + + self.graph = punct + self.fst = (pynutil.insert("name: \"") + self.graph + pynutil.insert("\"")).optimize() diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/taggers/range.py b/utils/speechio/nemo_text_processing/text_normalization/en/taggers/range.py new file mode 100644 index 0000000..9c237f9 --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/taggers/range.py @@ -0,0 +1,102 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import NEMO_DIGIT, GraphFst, convert_space +from pynini.lib import pynutil + + +class RangeFst(GraphFst): + """ + This class is a composite class of two other class instances + + Args: + time: composed tagger and verbalizer + date: composed tagger and verbalizer + cardinal: tagger + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + lm: whether to use for hybrid LM + """ + + def __init__( + self, time: GraphFst, date: GraphFst, cardinal: GraphFst, deterministic: bool = True, lm: bool = False, + ): + super().__init__(name="range", kind="classify", deterministic=deterministic) + + delete_space = pynini.closure(pynutil.delete(" "), 0, 1) + + approx = pynini.cross("~", "approximately") + + # TIME + time_graph = time + delete_space + pynini.cross("-", " to ") + delete_space + time + self.graph = time_graph | (approx + time) + + cardinal = cardinal.graph_with_and + # YEAR + date_year_four_digit = (NEMO_DIGIT ** 4 + pynini.closure(pynini.accep("s"), 0, 1)) @ date + date_year_two_digit = (NEMO_DIGIT ** 2 + pynini.closure(pynini.accep("s"), 0, 1)) @ date + year_to_year_graph = ( + date_year_four_digit + + delete_space + + pynini.cross("-", " to ") + + delete_space + + (date_year_four_digit | date_year_two_digit | (NEMO_DIGIT ** 2 @ cardinal)) + ) + mid_year_graph = pynini.accep("mid") + pynini.cross("-", " ") + (date_year_four_digit | date_year_two_digit) + + self.graph |= year_to_year_graph + self.graph |= mid_year_graph + + # ADDITION + range_graph = cardinal + pynini.closure(pynini.cross("+", " plus ") + cardinal, 1) + range_graph |= cardinal + pynini.closure(pynini.cross(" + ", " plus ") + cardinal, 1) + range_graph |= approx + cardinal + range_graph |= cardinal + (pynini.cross("...", " ... ") | pynini.accep(" ... 
")) + cardinal + + if not deterministic or lm: + # cardinal ---- + cardinal_to_cardinal_graph = ( + cardinal + delete_space + pynini.cross("-", pynini.union(" to ", " minus ")) + delete_space + cardinal + ) + + range_graph |= cardinal_to_cardinal_graph | ( + cardinal + delete_space + pynini.cross(":", " to ") + delete_space + cardinal + ) + + # MULTIPLY + for x in [" x ", "x"]: + range_graph |= cardinal + pynini.closure( + pynini.cross(x, pynini.union(" by ", " times ")) + cardinal, 1 + ) + + for x in ["*", " * "]: + range_graph |= cardinal + pynini.closure(pynini.cross(x, " times ") + cardinal, 1) + + # supports "No. 12" -> "Number 12" + range_graph |= ( + (pynini.cross(pynini.union("NO", "No"), "Number") | pynini.cross("no", "number")) + + pynini.closure(pynini.union(". ", " "), 0, 1) + + cardinal + ) + + for x in ["/", " / "]: + range_graph |= cardinal + pynini.closure(pynini.cross(x, " divided by ") + cardinal, 1) + + self.graph |= range_graph + + self.graph = self.graph.optimize() + graph = pynutil.insert("name: \"") + convert_space(self.graph).optimize() + pynutil.insert("\"") + self.fst = graph.optimize() diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/taggers/roman.py b/utils/speechio/nemo_text_processing/text_normalization/en/taggers/roman.py new file mode 100644 index 0000000..e12ee4a --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/taggers/roman.py @@ -0,0 +1,114 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import NEMO_ALPHA, NEMO_SIGMA, GraphFst +from nemo_text_processing.text_normalization.en.utils import get_abs_path, load_labels +from pynini.lib import pynutil + + +class RomanFst(GraphFst): + """ + Finite state transducer for classifying roman numbers: + e.g. "IV" -> tokens { roman { integer: "four" } } + + Args: + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + """ + + def __init__(self, deterministic: bool = True, lm: bool = False): + super().__init__(name="roman", kind="classify", deterministic=deterministic) + + roman_dict = load_labels(get_abs_path("data/roman/roman_to_spoken.tsv")) + default_graph = pynini.string_map(roman_dict).optimize() + default_graph = pynutil.insert("integer: \"") + default_graph + pynutil.insert("\"") + ordinal_limit = 19 + + if deterministic: + # exclude "I" + start_idx = 1 + else: + start_idx = 0 + + graph_teens = pynini.string_map([x[0] for x in roman_dict[start_idx:ordinal_limit]]).optimize() + + # roman numerals up to ordinal_limit with a preceding name are converted to ordinal form + names = get_names() + graph = ( + pynutil.insert("key_the_ordinal: \"") + + names + + pynutil.insert("\"") + + pynini.accep(" ") + + graph_teens @ default_graph + ).optimize() + + # single symbol roman numerals with preceding key words (multiple formats) are converted to cardinal form + key_words = [] + for k_word in load_labels(get_abs_path("data/roman/key_word.tsv")): + key_words.append(k_word) + key_words.append([k_word[0][0].upper() + k_word[0][1:]]) + key_words.append([k_word[0].upper()]) + + key_words = pynini.string_map(key_words).optimize() + graph |= ( + pynutil.insert("key_cardinal: \"") + key_words + pynutil.insert("\"") + pynini.accep(" ") + 
default_graph + ).optimize() + + if deterministic or lm: + # two digit roman numerals up to 49 + roman_to_cardinal = pynini.compose( + pynini.closure(NEMO_ALPHA, 2), + ( + pynutil.insert("default_cardinal: \"default\" ") + + (pynini.string_map([x[0] for x in roman_dict[:50]]).optimize()) @ default_graph + ), + ) + graph |= roman_to_cardinal + elif not lm: + # two or more digit roman numerals + roman_to_cardinal = pynini.compose( + pynini.difference(NEMO_SIGMA, "I"), + ( + pynutil.insert("default_cardinal: \"default\" integer: \"") + + pynini.string_map(roman_dict).optimize() + + pynutil.insert("\"") + ), + ).optimize() + graph |= roman_to_cardinal + + # convert three digit roman or up with suffix to ordinal + roman_to_ordinal = pynini.compose( + pynini.closure(NEMO_ALPHA, 3), + (pynutil.insert("default_ordinal: \"default\" ") + graph_teens @ default_graph + pynutil.delete("th")), + ) + + graph |= roman_to_ordinal + graph = self.add_tokens(graph.optimize()) + + self.fst = graph.optimize() + + +def get_names(): + """ + Returns the graph that matched common male and female names. + """ + male_labels = load_labels(get_abs_path("data/roman/male.tsv")) + female_labels = load_labels(get_abs_path("data/roman/female.tsv")) + male_labels.extend([[x[0].upper()] for x in male_labels]) + female_labels.extend([[x[0].upper()] for x in female_labels]) + names = pynini.string_map(male_labels).optimize() + names |= pynini.string_map(female_labels).optimize() + return names diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/taggers/serial.py b/utils/speechio/nemo_text_processing/text_normalization/en/taggers/serial.py new file mode 100644 index 0000000..669fd95 --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/taggers/serial.py @@ -0,0 +1,136 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import ( + NEMO_ALPHA, + NEMO_DIGIT, + NEMO_NOT_SPACE, + NEMO_SIGMA, + GraphFst, + convert_space, +) +from nemo_text_processing.text_normalization.en.utils import get_abs_path, load_labels +from pynini.examples import plurals +from pynini.lib import pynutil + + +class SerialFst(GraphFst): + """ + This class is a composite class of two other class instances + + Args: + time: composed tagger and verbalizer + date: composed tagger and verbalizer + cardinal: tagger + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + lm: whether to use for hybrid LM + """ + + def __init__(self, cardinal: GraphFst, ordinal: GraphFst, deterministic: bool = True, lm: bool = False): + super().__init__(name="integer", kind="classify", deterministic=deterministic) + + """ + Finite state transducer for classifying serial (handles only cases without delimiters, + values with delimiters are handled by default). 
+ The serial is a combination of digits, letters and dashes, e.g.: + c325b -> tokens { cardinal { integer: "c three two five b" } } + """ + num_graph = pynini.compose(NEMO_DIGIT ** (6, ...), cardinal.single_digits_graph).optimize() + num_graph |= pynini.compose(NEMO_DIGIT ** (1, 5), cardinal.graph).optimize() + # to handle numbers starting with zero + num_graph |= pynini.compose( + pynini.accep("0") + pynini.closure(NEMO_DIGIT), cardinal.single_digits_graph + ).optimize() + # TODO: "#" doesn't work from the file + symbols_graph = pynini.string_file(get_abs_path("data/whitelist/symbol.tsv")).optimize() | pynini.cross( + "#", "hash" + ) + num_graph |= symbols_graph + + if not self.deterministic and not lm: + num_graph |= cardinal.single_digits_graph + # also allow double digits to be pronounced as integer in serial number + num_graph |= pynutil.add_weight( + NEMO_DIGIT ** 2 @ cardinal.graph_hundred_component_at_least_one_none_zero_digit, weight=0.0001 + ) + + # add space between letter and digit/symbol + symbols = [x[0] for x in load_labels(get_abs_path("data/whitelist/symbol.tsv"))] + symbols = pynini.union(*symbols) + digit_symbol = NEMO_DIGIT | symbols + + graph_with_space = pynini.compose( + pynini.cdrewrite(pynutil.insert(" "), NEMO_ALPHA | symbols, digit_symbol, NEMO_SIGMA), + pynini.cdrewrite(pynutil.insert(" "), digit_symbol, NEMO_ALPHA | symbols, NEMO_SIGMA), + ) + + # serial graph with delimiter + delimiter = pynini.accep("-") | pynini.accep("/") | pynini.accep(" ") + if not deterministic: + delimiter |= pynini.cross("-", " dash ") | pynini.cross("/", " slash ") + + alphas = pynini.closure(NEMO_ALPHA, 1) + letter_num = alphas + delimiter + num_graph + num_letter = pynini.closure(num_graph + delimiter, 1) + alphas + next_alpha_or_num = pynini.closure(delimiter + (alphas | num_graph)) + next_alpha_or_num |= pynini.closure( + delimiter + + num_graph + + plurals._priority_union(pynini.accep(" "), pynutil.insert(" "), NEMO_SIGMA).optimize() + + alphas + ) + + 
serial_graph = letter_num + next_alpha_or_num + serial_graph |= num_letter + next_alpha_or_num + # numbers only with 2+ delimiters + serial_graph |= ( + num_graph + delimiter + num_graph + delimiter + num_graph + pynini.closure(delimiter + num_graph) + ) + # 2+ symbols + serial_graph |= pynini.compose(NEMO_SIGMA + symbols + NEMO_SIGMA, num_graph + delimiter + num_graph) + + # exclude ordinal numbers from serial options + serial_graph = pynini.compose( + pynini.difference(NEMO_SIGMA, pynini.project(ordinal.graph, "input")), serial_graph + ).optimize() + + serial_graph = pynutil.add_weight(serial_graph, 0.0001) + serial_graph |= ( + pynini.closure(NEMO_NOT_SPACE, 1) + + (pynini.cross("^2", " squared") | pynini.cross("^3", " cubed")).optimize() + ) + + # at least one serial graph with alpha numeric value and optional additional serial/num/alpha values + serial_graph = ( + pynini.closure((serial_graph | num_graph | alphas) + delimiter) + + serial_graph + + pynini.closure(delimiter + (serial_graph | num_graph | alphas)) + ) + + serial_graph |= pynini.compose(graph_with_space, serial_graph.optimize()).optimize() + serial_graph = pynini.compose(pynini.closure(NEMO_NOT_SPACE, 2), serial_graph).optimize() + + # this is not to verbolize "/" as "slash" in cases like "import/export" + serial_graph = pynini.compose( + pynini.difference( + NEMO_SIGMA, pynini.closure(NEMO_ALPHA, 1) + pynini.accep("/") + pynini.closure(NEMO_ALPHA, 1) + ), + serial_graph, + ) + self.graph = serial_graph.optimize() + graph = pynutil.insert("name: \"") + convert_space(self.graph).optimize() + pynutil.insert("\"") + self.fst = graph.optimize() diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/taggers/telephone.py b/utils/speechio/nemo_text_processing/text_normalization/en/taggers/telephone.py new file mode 100644 index 0000000..1caedff --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/taggers/telephone.py @@ -0,0 +1,133 @@ +# Copyright (c) 2021, NVIDIA 
CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import ( + NEMO_ALPHA, + NEMO_DIGIT, + NEMO_SIGMA, + GraphFst, + delete_extra_space, + delete_space, + insert_space, + plurals, +) +from nemo_text_processing.text_normalization.en.utils import get_abs_path +from pynini.lib import pynutil + + +class TelephoneFst(GraphFst): + """ + Finite state transducer for classifying telephone, and IP, and SSN which includes country code, number part and extension + country code optional: +*** + number part: ***-***-****, or (***) ***-**** + extension optional: 1-9999 + E.g + +1 123-123-5678-1 -> telephone { country_code: "one" number_part: "one two three, one two three, five six seven eight" extension: "one" } + 1-800-GO-U-HAUL -> telephone { country_code: "one" number_part: "one, eight hundred GO U HAUL" } + Args: + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + """ + + def __init__(self, deterministic: bool = True): + super().__init__(name="telephone", kind="classify", deterministic=deterministic) + + add_separator = pynutil.insert(", ") # between components + zero = pynini.cross("0", "zero") + if not deterministic: + zero |= pynini.cross("0", pynini.union("o", "oh")) + digit = pynini.invert(pynini.string_file(get_abs_path("data/number/digit.tsv"))).optimize() | zero + 
+ telephone_prompts = pynini.string_file(get_abs_path("data/telephone/telephone_prompt.tsv")) + country_code = ( + pynini.closure(telephone_prompts + delete_extra_space, 0, 1) + + pynini.closure(pynini.cross("+", "plus "), 0, 1) + + pynini.closure(digit + insert_space, 0, 2) + + digit + + pynutil.insert(",") + ) + country_code |= telephone_prompts + country_code = pynutil.insert("country_code: \"") + country_code + pynutil.insert("\"") + country_code = country_code + pynini.closure(pynutil.delete("-"), 0, 1) + delete_space + insert_space + + area_part_default = pynini.closure(digit + insert_space, 2, 2) + digit + area_part = pynini.cross("800", "eight hundred") | pynini.compose( + pynini.difference(NEMO_SIGMA, "800"), area_part_default + ) + + area_part = ( + (area_part + (pynutil.delete("-") | pynutil.delete("."))) + | ( + pynutil.delete("(") + + area_part + + ((pynutil.delete(")") + pynini.closure(pynutil.delete(" "), 0, 1)) | pynutil.delete(")-")) + ) + ) + add_separator + + del_separator = pynini.closure(pynini.union("-", " ", "."), 0, 1) + number_length = ((NEMO_DIGIT + del_separator) | (NEMO_ALPHA + del_separator)) ** 7 + number_words = pynini.closure( + (NEMO_DIGIT @ digit) + (insert_space | (pynini.cross("-", ', '))) + | NEMO_ALPHA + | (NEMO_ALPHA + pynini.cross("-", ' ')) + ) + number_words |= pynini.closure( + (NEMO_DIGIT @ digit) + (insert_space | (pynini.cross(".", ', '))) + | NEMO_ALPHA + | (NEMO_ALPHA + pynini.cross(".", ' ')) + ) + number_words = pynini.compose(number_length, number_words) + number_part = area_part + number_words + number_part = pynutil.insert("number_part: \"") + number_part + pynutil.insert("\"") + extension = ( + pynutil.insert("extension: \"") + pynini.closure(digit + insert_space, 0, 3) + digit + pynutil.insert("\"") + ) + extension = pynini.closure(insert_space + extension, 0, 1) + + graph = plurals._priority_union(country_code + number_part, number_part, NEMO_SIGMA).optimize() + graph = plurals._priority_union(country_code + 
number_part + extension, graph, NEMO_SIGMA).optimize() + graph = plurals._priority_union(number_part + extension, graph, NEMO_SIGMA).optimize() + + # ip + ip_prompts = pynini.string_file(get_abs_path("data/telephone/ip_prompt.tsv")) + digit_to_str_graph = digit + pynini.closure(pynutil.insert(" ") + digit, 0, 2) + ip_graph = digit_to_str_graph + (pynini.cross(".", " dot ") + digit_to_str_graph) ** 3 + graph |= ( + pynini.closure( + pynutil.insert("country_code: \"") + ip_prompts + pynutil.insert("\"") + delete_extra_space, 0, 1 + ) + + pynutil.insert("number_part: \"") + + ip_graph.optimize() + + pynutil.insert("\"") + ) + # ssn + ssn_prompts = pynini.string_file(get_abs_path("data/telephone/ssn_prompt.tsv")) + three_digit_part = digit + (pynutil.insert(" ") + digit) ** 2 + two_digit_part = digit + pynutil.insert(" ") + digit + four_digit_part = digit + (pynutil.insert(" ") + digit) ** 3 + ssn_separator = pynini.cross("-", ", ") + ssn_graph = three_digit_part + ssn_separator + two_digit_part + ssn_separator + four_digit_part + + graph |= ( + pynini.closure( + pynutil.insert("country_code: \"") + ssn_prompts + pynutil.insert("\"") + delete_extra_space, 0, 1 + ) + + pynutil.insert("number_part: \"") + + ssn_graph.optimize() + + pynutil.insert("\"") + ) + + final_graph = self.add_tokens(graph) + self.fst = final_graph.optimize() diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/taggers/time.py b/utils/speechio/nemo_text_processing/text_normalization/en/taggers/time.py new file mode 100644 index 0000000..4020996 --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/taggers/time.py @@ -0,0 +1,132 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import ( + NEMO_DIGIT, + GraphFst, + convert_space, + delete_space, + insert_space, +) +from nemo_text_processing.text_normalization.en.utils import ( + augment_labels_with_punct_at_end, + get_abs_path, + load_labels, +) +from pynini.lib import pynutil + + +class TimeFst(GraphFst): + """ + Finite state transducer for classifying time, e.g. + 12:30 a.m. est -> time { hours: "twelve" minutes: "thirty" suffix: "a m" zone: "e s t" } + 2.30 a.m. -> time { hours: "two" minutes: "thirty" suffix: "a m" } + 02.30 a.m. -> time { hours: "two" minutes: "thirty" suffix: "a m" } + 2.00 a.m. -> time { hours: "two" suffix: "a m" } + 2 a.m. -> time { hours: "two" suffix: "a m" } + 02:00 -> time { hours: "two" } + 2:00 -> time { hours: "two" } + 10:00:05 a.m. 
-> time { hours: "ten" minutes: "zero" seconds: "five" suffix: "a m" } + + Args: + cardinal: CardinalFst + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + """ + + def __init__(self, cardinal: GraphFst, deterministic: bool = True): + super().__init__(name="time", kind="classify", deterministic=deterministic) + suffix_labels = load_labels(get_abs_path("data/time/suffix.tsv")) + suffix_labels.extend(augment_labels_with_punct_at_end(suffix_labels)) + suffix_graph = pynini.string_map(suffix_labels) + + time_zone_graph = pynini.string_file(get_abs_path("data/time/zone.tsv")) + + # only used for < 1000 thousand -> 0 weight + cardinal = cardinal.graph + + labels_hour = [str(x) for x in range(0, 24)] + labels_minute_single = [str(x) for x in range(1, 10)] + labels_minute_double = [str(x) for x in range(10, 60)] + + delete_leading_zero_to_double_digit = (NEMO_DIGIT + NEMO_DIGIT) | ( + pynini.closure(pynutil.delete("0"), 0, 1) + NEMO_DIGIT + ) + + graph_hour = delete_leading_zero_to_double_digit @ pynini.union(*labels_hour) @ cardinal + + graph_minute_single = pynini.union(*labels_minute_single) @ cardinal + graph_minute_double = pynini.union(*labels_minute_double) @ cardinal + + final_graph_hour = pynutil.insert("hours: \"") + graph_hour + pynutil.insert("\"") + final_graph_minute = ( + pynutil.insert("minutes: \"") + + (pynini.cross("0", "o") + insert_space + graph_minute_single | graph_minute_double) + + pynutil.insert("\"") + ) + final_graph_second = ( + pynutil.insert("seconds: \"") + + (pynini.cross("0", "o") + insert_space + graph_minute_single | graph_minute_double) + + pynutil.insert("\"") + ) + final_suffix = pynutil.insert("suffix: \"") + convert_space(suffix_graph) + pynutil.insert("\"") + final_suffix_optional = pynini.closure(delete_space + insert_space + final_suffix, 0, 1) + final_time_zone_optional = pynini.closure( + delete_space + + insert_space + + 
pynutil.insert("zone: \"") + + convert_space(time_zone_graph) + + pynutil.insert("\""), + 0, + 1, + ) + + # 2:30 pm, 02:30, 2:00 + graph_hm = ( + final_graph_hour + + pynutil.delete(":") + + (pynutil.delete("00") | insert_space + final_graph_minute) + + final_suffix_optional + + final_time_zone_optional + ) + + # 10:30:05 pm, + graph_hms = ( + final_graph_hour + + pynutil.delete(":") + + (pynini.cross("00", " minutes: \"zero\"") | insert_space + final_graph_minute) + + pynutil.delete(":") + + (pynini.cross("00", " seconds: \"zero\"") | insert_space + final_graph_second) + + final_suffix_optional + + final_time_zone_optional + ) + + # 2.xx pm/am + graph_hm2 = ( + final_graph_hour + + pynutil.delete(".") + + (pynutil.delete("00") | insert_space + final_graph_minute) + + delete_space + + insert_space + + final_suffix + + final_time_zone_optional + ) + # 2 pm est + graph_h = final_graph_hour + delete_space + insert_space + final_suffix + final_time_zone_optional + final_graph = (graph_hm | graph_h | graph_hm2 | graph_hms).optimize() + + final_graph = self.add_tokens(final_graph) + self.fst = final_graph.optimize() diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/taggers/tokenize_and_classify.py b/utils/speechio/nemo_text_processing/text_normalization/en/taggers/tokenize_and_classify.py new file mode 100644 index 0000000..53ae71e --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/taggers/tokenize_and_classify.py @@ -0,0 +1,201 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time + +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import ( + NEMO_WHITE_SPACE, + GraphFst, + delete_extra_space, + delete_space, + generator_main, +) +from nemo_text_processing.text_normalization.en.taggers.abbreviation import AbbreviationFst +from nemo_text_processing.text_normalization.en.taggers.cardinal import CardinalFst +from nemo_text_processing.text_normalization.en.taggers.date import DateFst +from nemo_text_processing.text_normalization.en.taggers.decimal import DecimalFst +from nemo_text_processing.text_normalization.en.taggers.electronic import ElectronicFst +from nemo_text_processing.text_normalization.en.taggers.fraction import FractionFst +from nemo_text_processing.text_normalization.en.taggers.measure import MeasureFst +from nemo_text_processing.text_normalization.en.taggers.money import MoneyFst +from nemo_text_processing.text_normalization.en.taggers.ordinal import OrdinalFst +from nemo_text_processing.text_normalization.en.taggers.punctuation import PunctuationFst +from nemo_text_processing.text_normalization.en.taggers.range import RangeFst as RangeFst +from nemo_text_processing.text_normalization.en.taggers.roman import RomanFst +from nemo_text_processing.text_normalization.en.taggers.serial import SerialFst +from nemo_text_processing.text_normalization.en.taggers.telephone import TelephoneFst +from nemo_text_processing.text_normalization.en.taggers.time import TimeFst +from nemo_text_processing.text_normalization.en.taggers.whitelist import WhiteListFst +from 
nemo_text_processing.text_normalization.en.taggers.word import WordFst +from nemo_text_processing.text_normalization.en.verbalizers.date import DateFst as vDateFst +from nemo_text_processing.text_normalization.en.verbalizers.ordinal import OrdinalFst as vOrdinalFst +from nemo_text_processing.text_normalization.en.verbalizers.time import TimeFst as vTimeFst +from pynini.lib import pynutil + + + +class ClassifyFst(GraphFst): + """ + Final class that composes all other classification grammars. This class can process an entire sentence including punctuation. + For deployment, this grammar will be compiled and exported to OpenFst Finate State Archiv (FAR) File. + More details to deployment at NeMo/tools/text_processing_deployment. + + Args: + input_case: accepting either "lower_cased" or "cased" input. + deterministic: if True will provide a single transduction option, + for False multiple options (used for audio-based normalization) + cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache. 
+ overwrite_cache: set to True to overwrite .far files + whitelist: path to a file with whitelist replacements + """ + + def __init__( + self, + input_case: str, + deterministic: bool = True, + cache_dir: str = None, + overwrite_cache: bool = False, + whitelist: str = None, + ): + super().__init__(name="tokenize_and_classify", kind="classify", deterministic=deterministic) + + far_file = None + if cache_dir is not None and cache_dir != "None": + os.makedirs(cache_dir, exist_ok=True) + whitelist_file = os.path.basename(whitelist) if whitelist else "" + far_file = os.path.join( + cache_dir, f"en_tn_{deterministic}_deterministic_{input_case}_{whitelist_file}_tokenize.far" + ) + if not overwrite_cache and far_file and os.path.exists(far_file): + self.fst = pynini.Far(far_file, mode="r")["tokenize_and_classify"] + else: + + start_time = time.time() + cardinal = CardinalFst(deterministic=deterministic) + cardinal_graph = cardinal.fst + + start_time = time.time() + ordinal = OrdinalFst(cardinal=cardinal, deterministic=deterministic) + ordinal_graph = ordinal.fst + + start_time = time.time() + decimal = DecimalFst(cardinal=cardinal, deterministic=deterministic) + decimal_graph = decimal.fst + + start_time = time.time() + fraction = FractionFst(deterministic=deterministic, cardinal=cardinal) + fraction_graph = fraction.fst + + start_time = time.time() + measure = MeasureFst(cardinal=cardinal, decimal=decimal, fraction=fraction, deterministic=deterministic) + measure_graph = measure.fst + + start_time = time.time() + date_graph = DateFst(cardinal=cardinal, deterministic=deterministic).fst + + start_time = time.time() + time_graph = TimeFst(cardinal=cardinal, deterministic=deterministic).fst + + start_time = time.time() + telephone_graph = TelephoneFst(deterministic=deterministic).fst + + start_time = time.time() + electonic_graph = ElectronicFst(deterministic=deterministic).fst + + start_time = time.time() + money_graph = MoneyFst(cardinal=cardinal, decimal=decimal, 
deterministic=deterministic).fst + + + start_time = time.time() + whitelist_graph = WhiteListFst( + input_case=input_case, deterministic=deterministic, input_file=whitelist + ).fst + + start_time = time.time() + punctuation = PunctuationFst(deterministic=deterministic) + punct_graph = punctuation.fst + + start_time = time.time() + word_graph = WordFst(punctuation=punctuation, deterministic=deterministic).fst + + + start_time = time.time() + serial_graph = SerialFst(cardinal=cardinal, ordinal=ordinal, deterministic=deterministic).fst + + + start_time = time.time() + v_time_graph = vTimeFst(deterministic=deterministic).fst + v_ordinal_graph = vOrdinalFst(deterministic=deterministic) + v_date_graph = vDateFst(ordinal=v_ordinal_graph, deterministic=deterministic).fst + time_final = pynini.compose(time_graph, v_time_graph) + date_final = pynini.compose(date_graph, v_date_graph) + range_graph = RangeFst( + time=time_final, date=date_final, cardinal=cardinal, deterministic=deterministic + ).fst + + + classify = ( + pynutil.add_weight(whitelist_graph, 1.01) + | pynutil.add_weight(time_graph, 1.1) + | pynutil.add_weight(date_graph, 1.09) + | pynutil.add_weight(decimal_graph, 1.1) + | pynutil.add_weight(measure_graph, 1.1) + | pynutil.add_weight(cardinal_graph, 1.1) + | pynutil.add_weight(ordinal_graph, 1.1) + | pynutil.add_weight(money_graph, 1.1) + | pynutil.add_weight(telephone_graph, 1.1) + | pynutil.add_weight(electonic_graph, 1.1) + | pynutil.add_weight(fraction_graph, 1.1) + | pynutil.add_weight(range_graph, 1.1) + | pynutil.add_weight(serial_graph, 1.1001) # should be higher than the rest of the classes + ) + + roman_graph = RomanFst(deterministic=deterministic).fst + classify |= pynutil.add_weight(roman_graph, 1.1) + + if not deterministic: + abbreviation_graph = AbbreviationFst(deterministic=deterministic).fst + classify |= pynutil.add_weight(abbreviation_graph, 100) + + punct = pynutil.insert("tokens { ") + pynutil.add_weight(punct_graph, weight=2.1) + 
pynutil.insert(" }") + punct = pynini.closure( + pynini.compose(pynini.closure(NEMO_WHITE_SPACE, 1), delete_extra_space) + | (pynutil.insert(" ") + punct), + 1, + ) + + classify |= pynutil.add_weight(word_graph, 100) + token = pynutil.insert("tokens { ") + classify + pynutil.insert(" }") + token_plus_punct = ( + pynini.closure(punct + pynutil.insert(" ")) + token + pynini.closure(pynutil.insert(" ") + punct) + ) + + graph = token_plus_punct + pynini.closure( + ( + pynini.compose(pynini.closure(NEMO_WHITE_SPACE, 1), delete_extra_space) + | (pynutil.insert(" ") + punct + pynutil.insert(" ")) + ) + + token_plus_punct + ) + + graph = delete_space + graph + delete_space + graph |= punct + + self.fst = graph.optimize() + + if far_file: + generator_main(far_file, {"tokenize_and_classify": self.fst}) + diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/taggers/tokenize_and_classify_lm.py b/utils/speechio/nemo_text_processing/text_normalization/en/taggers/tokenize_and_classify_lm.py new file mode 100644 index 0000000..fa48c37 --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/taggers/tokenize_and_classify_lm.py @@ -0,0 +1,228 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import os

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    NEMO_CHAR,
    NEMO_DIGIT,
    NEMO_NOT_SPACE,
    NEMO_SIGMA,
    NEMO_WHITE_SPACE,
    GraphFst,
    delete_extra_space,
    delete_space,
    generator_main,
)
from nemo_text_processing.text_normalization.en.taggers.cardinal import CardinalFst
from nemo_text_processing.text_normalization.en.taggers.date import DateFst
from nemo_text_processing.text_normalization.en.taggers.decimal import DecimalFst
from nemo_text_processing.text_normalization.en.taggers.electronic import ElectronicFst
from nemo_text_processing.text_normalization.en.taggers.fraction import FractionFst
from nemo_text_processing.text_normalization.en.taggers.measure import MeasureFst
from nemo_text_processing.text_normalization.en.taggers.money import MoneyFst
from nemo_text_processing.text_normalization.en.taggers.ordinal import OrdinalFst
from nemo_text_processing.text_normalization.en.taggers.punctuation import PunctuationFst
from nemo_text_processing.text_normalization.en.taggers.range import RangeFst as RangeFst
from nemo_text_processing.text_normalization.en.taggers.roman import RomanFst
from nemo_text_processing.text_normalization.en.taggers.serial import SerialFst
from nemo_text_processing.text_normalization.en.taggers.telephone import TelephoneFst
from nemo_text_processing.text_normalization.en.taggers.time import TimeFst
from nemo_text_processing.text_normalization.en.taggers.whitelist import WhiteListFst
from nemo_text_processing.text_normalization.en.taggers.word import WordFst
from nemo_text_processing.text_normalization.en.verbalizers.cardinal import CardinalFst as vCardinal
from nemo_text_processing.text_normalization.en.verbalizers.date import DateFst as vDate
from nemo_text_processing.text_normalization.en.verbalizers.decimal import DecimalFst as vDecimal
from nemo_text_processing.text_normalization.en.verbalizers.electronic import ElectronicFst as vElectronic
from nemo_text_processing.text_normalization.en.verbalizers.fraction import FractionFst as vFraction
from nemo_text_processing.text_normalization.en.verbalizers.measure import MeasureFst as vMeasure
from nemo_text_processing.text_normalization.en.verbalizers.money import MoneyFst as vMoney
from nemo_text_processing.text_normalization.en.verbalizers.ordinal import OrdinalFst as vOrdinal
from nemo_text_processing.text_normalization.en.verbalizers.roman import RomanFst as vRoman
from nemo_text_processing.text_normalization.en.verbalizers.telephone import TelephoneFst as vTelephone
from nemo_text_processing.text_normalization.en.verbalizers.time import TimeFst as vTime
from nemo_text_processing.text_normalization.en.verbalizers.word import WordFst as vWord
from pynini.examples import plurals
from pynini.lib import pynutil

from nemo.utils import logging


class ClassifyFst(GraphFst):
    """
    Final class that composes all other classification grammars. This class can process an entire sentence including punctuation.
    For deployment, this grammar will be compiled and exported to OpenFst Finite State Archive (FAR) File.
    More details to deployment at NeMo/tools/text_processing_deployment.

    NOTE(review): this variant builds most taggers/verbalizers with
    deterministic=True and caches under a "_lm.far" suffix — presumably the
    LM-rescoring flavor of tokenize_and_classify; confirm against callers.

    Args:
        input_case: accepting either "lower_cased" or "cased" input.
        deterministic: if True will provide a single transduction option,
            for False multiple options (used for audio-based normalization)
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
        whitelist: path to a file with whitelist replacements
    """

    def __init__(
        self,
        input_case: str,
        deterministic: bool = True,
        cache_dir: str = None,
        overwrite_cache: bool = True,
        whitelist: str = None,
    ):
        super().__init__(name="tokenize_and_classify", kind="classify", deterministic=deterministic)

        # Derive a cache file name that encodes the build options, so different
        # configurations do not clobber each other's compiled grammars.
        far_file = None
        if cache_dir is not None and cache_dir != 'None':
            os.makedirs(cache_dir, exist_ok=True)
            whitelist_file = os.path.basename(whitelist) if whitelist else ""
            far_file = os.path.join(
                cache_dir, f"_{input_case}_en_tn_{deterministic}_deterministic{whitelist_file}_lm.far"
            )
        if not overwrite_cache and far_file and os.path.exists(far_file):
            # Fast path: restore the precompiled grammar from the FAR archive.
            self.fst = pynini.Far(far_file, mode='r')['tokenize_and_classify']
            no_digits = pynini.closure(pynini.difference(NEMO_CHAR, NEMO_DIGIT))
            self.fst_no_digits = pynini.compose(self.fst, no_digits).optimize()
            logging.info(f'ClassifyFst.fst was restored from {far_file}.')
        else:
            logging.info(f'Creating ClassifyFst grammars. This might take some time...')
            # TAGGERS
            cardinal = CardinalFst(deterministic=True, lm=True)
            cardinal_tagger = cardinal  # kept for RangeFst, which needs the tagger object
            cardinal_graph = cardinal.fst

            ordinal = OrdinalFst(cardinal=cardinal, deterministic=True)
            ordinal_graph = ordinal.fst

            decimal = DecimalFst(cardinal=cardinal, deterministic=True)
            decimal_graph = decimal.fst
            fraction = FractionFst(deterministic=True, cardinal=cardinal)
            fraction_graph = fraction.fst

            measure = MeasureFst(cardinal=cardinal, decimal=decimal, fraction=fraction, deterministic=True)
            measure_graph = measure.fst
            date = DateFst(cardinal=cardinal, deterministic=True, lm=True)
            date_graph = date.fst
            punctuation = PunctuationFst(deterministic=True)
            punct_graph = punctuation.graph
            word_graph = WordFst(punctuation=punctuation, deterministic=deterministic).graph
            time_graph = TimeFst(cardinal=cardinal, deterministic=True).fst
            telephone_graph = TelephoneFst(deterministic=True).fst
            electronic_graph = ElectronicFst(deterministic=True).fst
            money_graph = MoneyFst(cardinal=cardinal, decimal=decimal, deterministic=False).fst
            whitelist = WhiteListFst(input_case=input_case, deterministic=False, input_file=whitelist)
            whitelist_graph = whitelist.graph
            serial_graph = SerialFst(cardinal=cardinal, ordinal=ordinal, deterministic=deterministic, lm=True).fst

            # VERBALIZERS
            cardinal = vCardinal(deterministic=True)
            v_cardinal_graph = cardinal.fst
            decimal = vDecimal(cardinal=cardinal, deterministic=True)
            v_decimal_graph = decimal.fst
            ordinal = vOrdinal(deterministic=True)
            v_ordinal_graph = ordinal.fst
            fraction = vFraction(deterministic=True, lm=True)
            v_fraction_graph = fraction.fst
            v_telephone_graph = vTelephone(deterministic=True).fst
            v_electronic_graph = vElectronic(deterministic=True).fst
            measure = vMeasure(decimal=decimal, cardinal=cardinal, fraction=fraction, deterministic=False)
            v_measure_graph = measure.fst
            v_time_graph = vTime(deterministic=True).fst
            v_date_graph = vDate(ordinal=ordinal, deterministic=deterministic, lm=True).fst
            v_money_graph = vMoney(decimal=decimal, deterministic=deterministic).fst
            v_roman_graph = vRoman(deterministic=deterministic).fst
            v_word_graph = vWord(deterministic=deterministic).fst

            # Prefer a date reading over a plain cardinal when both match.
            cardinal_or_date_final = plurals._priority_union(date_graph, cardinal_graph, NEMO_SIGMA)
            cardinal_or_date_final = pynini.compose(cardinal_or_date_final, (v_cardinal_graph | v_date_graph))

            time_final = pynini.compose(time_graph, v_time_graph)
            ordinal_final = pynini.compose(ordinal_graph, v_ordinal_graph)
            # Relative weights: semiotic classes (sem_w) beat punctuation (punct_w),
            # and a plain-word fallback (word_w) is used only as a last resort.
            sem_w = 1
            word_w = 100
            punct_w = 2
            classify_and_verbalize = (
                pynutil.add_weight(time_final, sem_w)
                | pynutil.add_weight(pynini.compose(decimal_graph, v_decimal_graph), sem_w)
                | pynutil.add_weight(pynini.compose(measure_graph, v_measure_graph), sem_w)
                | pynutil.add_weight(ordinal_final, sem_w)
                | pynutil.add_weight(pynini.compose(telephone_graph, v_telephone_graph), sem_w)
                | pynutil.add_weight(pynini.compose(electronic_graph, v_electronic_graph), sem_w)
                | pynutil.add_weight(pynini.compose(fraction_graph, v_fraction_graph), sem_w)
                | pynutil.add_weight(pynini.compose(money_graph, v_money_graph), sem_w)
                | pynutil.add_weight(cardinal_or_date_final, sem_w)
                | pynutil.add_weight(whitelist_graph, sem_w)
                | pynutil.add_weight(
                    pynini.compose(serial_graph, v_word_graph), 1.1001
                )  # should be higher than the rest of the classes
            ).optimize()

            roman_graph = RomanFst(deterministic=deterministic, lm=True).fst
            # the weight matches the word_graph weight for "I" cases in long sentences with multiple semiotic tokens
            classify_and_verbalize |= pynutil.add_weight(pynini.compose(roman_graph, v_roman_graph), sem_w)

            date_final = pynini.compose(date_graph, v_date_graph)
            range_graph = RangeFst(
                time=time_final, cardinal=cardinal_tagger, date=date_final, deterministic=deterministic
            ).fst
            classify_and_verbalize |= pynutil.add_weight(pynini.compose(range_graph, v_word_graph), sem_w)
            # Wrap recognized semiotic tokens in "< ... >" markers; plain words pass through unwrapped.
            classify_and_verbalize = pynutil.insert("< ") + classify_and_verbalize + pynutil.insert(" >")
            classify_and_verbalize |= pynutil.add_weight(word_graph, word_w)

            punct_only = pynutil.add_weight(punct_graph, weight=punct_w)
            punct = pynini.closure(
                pynini.compose(pynini.closure(NEMO_WHITE_SPACE, 1), delete_extra_space)
                | (pynutil.insert(" ") + punct_only),
                1,
            )

            def get_token_sem_graph(classify_and_verbalize):
                # Stitch token grammars into a sentence-level grammar: tokens with
                # optional surrounding punctuation, separated by single spaces.
                token_plus_punct = (
                    pynini.closure(punct + pynutil.insert(" "))
                    + classify_and_verbalize
                    + pynini.closure(pynutil.insert(" ") + punct)
                )

                graph = token_plus_punct + pynini.closure(
                    (
                        pynini.compose(pynini.closure(NEMO_WHITE_SPACE, 1), delete_extra_space)
                        | (pynutil.insert(" ") + punct + pynutil.insert(" "))
                    )
                    + token_plus_punct
                )

                graph |= punct_only + pynini.closure(punct)
                graph = delete_space + graph + delete_space

                # Collapse runs of spaces and trim leading spaces from the output.
                remove_extra_spaces = pynini.closure(NEMO_NOT_SPACE, 1) + pynini.closure(
                    delete_extra_space + pynini.closure(NEMO_NOT_SPACE, 1)
                )
                remove_extra_spaces |= (
                    pynini.closure(pynutil.delete(" "), 1)
                    + pynini.closure(NEMO_NOT_SPACE, 1)
                    + pynini.closure(delete_extra_space + pynini.closure(NEMO_NOT_SPACE, 1))
                )

                graph = pynini.compose(graph.optimize(), remove_extra_spaces).optimize()
                return graph

            self.fst = get_token_sem_graph(classify_and_verbalize)
            # Companion FST that additionally forbids digits in the output
            # (forces fully verbalized text).
            no_digits = pynini.closure(pynini.difference(NEMO_CHAR, NEMO_DIGIT))
            self.fst_no_digits = pynini.compose(self.fst, no_digits).optimize()

            if far_file:
                generator_main(far_file, {"tokenize_and_classify": self.fst})
                logging.info(f'ClassifyFst grammars are saved to {far_file}.')
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    NEMO_CHAR,
    NEMO_DIGIT,
    NEMO_NOT_SPACE,
    NEMO_WHITE_SPACE,
    GraphFst,
    delete_extra_space,
    delete_space,
    generator_main,
)
from nemo_text_processing.text_normalization.en.taggers.abbreviation import AbbreviationFst
from nemo_text_processing.text_normalization.en.taggers.cardinal import CardinalFst
from nemo_text_processing.text_normalization.en.taggers.date import DateFst
from nemo_text_processing.text_normalization.en.taggers.decimal import DecimalFst
from nemo_text_processing.text_normalization.en.taggers.electronic import ElectronicFst
from nemo_text_processing.text_normalization.en.taggers.fraction import FractionFst
from nemo_text_processing.text_normalization.en.taggers.measure import MeasureFst
from nemo_text_processing.text_normalization.en.taggers.money import MoneyFst
from nemo_text_processing.text_normalization.en.taggers.ordinal import OrdinalFst
from nemo_text_processing.text_normalization.en.taggers.punctuation import PunctuationFst
from nemo_text_processing.text_normalization.en.taggers.range import RangeFst as RangeFst
from nemo_text_processing.text_normalization.en.taggers.roman import RomanFst
from nemo_text_processing.text_normalization.en.taggers.serial import SerialFst
from nemo_text_processing.text_normalization.en.taggers.telephone import TelephoneFst
from nemo_text_processing.text_normalization.en.taggers.time import TimeFst
from nemo_text_processing.text_normalization.en.taggers.whitelist import WhiteListFst
from nemo_text_processing.text_normalization.en.taggers.word import WordFst
from nemo_text_processing.text_normalization.en.verbalizers.abbreviation import AbbreviationFst as vAbbreviation
from nemo_text_processing.text_normalization.en.verbalizers.cardinal import CardinalFst as vCardinal
from nemo_text_processing.text_normalization.en.verbalizers.date import DateFst as vDate
from nemo_text_processing.text_normalization.en.verbalizers.decimal import DecimalFst as vDecimal
from nemo_text_processing.text_normalization.en.verbalizers.electronic import ElectronicFst as vElectronic
from nemo_text_processing.text_normalization.en.verbalizers.fraction import FractionFst as vFraction
from nemo_text_processing.text_normalization.en.verbalizers.measure import MeasureFst as vMeasure
from nemo_text_processing.text_normalization.en.verbalizers.money import MoneyFst as vMoney
from nemo_text_processing.text_normalization.en.verbalizers.ordinal import OrdinalFst as vOrdinal
from nemo_text_processing.text_normalization.en.verbalizers.roman import RomanFst as vRoman
from nemo_text_processing.text_normalization.en.verbalizers.telephone import TelephoneFst as vTelephone
from nemo_text_processing.text_normalization.en.verbalizers.time import TimeFst as vTime
from nemo_text_processing.text_normalization.en.verbalizers.word import WordFst as vWord
from pynini.lib import pynutil

from nemo.utils import logging


class ClassifyFst(GraphFst):
    """
    Final class that composes all other classification grammars. This class can process an entire sentence including punctuation.
    For deployment, this grammar will be compiled and exported to OpenFst Finite State Archive (FAR) File.
    More details to deployment at NeMo/tools/text_processing_deployment.

    NOTE(review): this is the audio-based variant — with deterministic=False it
    also enables Roman-numeral and abbreviation expansion for alternatives.

    Args:
        input_case: accepting either "lower_cased" or "cased" input.
        deterministic: if True will provide a single transduction option,
            for False multiple options (used for audio-based normalization)
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
        whitelist: path to a file with whitelist replacements
    """

    def __init__(
        self,
        input_case: str,
        deterministic: bool = True,
        cache_dir: str = None,
        overwrite_cache: bool = True,
        whitelist: str = None,
    ):
        super().__init__(name="tokenize_and_classify", kind="classify", deterministic=deterministic)

        # Cache file name encodes the build options so distinct configs coexist.
        far_file = None
        if cache_dir is not None and cache_dir != 'None':
            os.makedirs(cache_dir, exist_ok=True)
            whitelist_file = os.path.basename(whitelist) if whitelist else ""
            far_file = os.path.join(
                cache_dir, f"_{input_case}_en_tn_{deterministic}_deterministic{whitelist_file}.far"
            )
        if not overwrite_cache and far_file and os.path.exists(far_file):
            # Fast path: restore a precompiled grammar from the FAR archive.
            self.fst = pynini.Far(far_file, mode='r')['tokenize_and_classify']
            no_digits = pynini.closure(pynini.difference(NEMO_CHAR, NEMO_DIGIT))
            self.fst_no_digits = pynini.compose(self.fst, no_digits).optimize()
            logging.info(f'ClassifyFst.fst was restored from {far_file}.')
        else:
            logging.info(f'Creating ClassifyFst grammars. This might take some time...')
            # TAGGERS
            cardinal = CardinalFst(deterministic=deterministic)
            cardinal_graph = cardinal.fst

            ordinal = OrdinalFst(cardinal=cardinal, deterministic=deterministic)
            # SerialFst below requires a deterministic ordinal regardless of the mode.
            deterministic_ordinal = OrdinalFst(cardinal=cardinal, deterministic=True)
            ordinal_graph = ordinal.fst

            decimal = DecimalFst(cardinal=cardinal, deterministic=deterministic)
            decimal_graph = decimal.fst
            fraction = FractionFst(deterministic=deterministic, cardinal=cardinal)
            fraction_graph = fraction.fst

            measure = MeasureFst(cardinal=cardinal, decimal=decimal, fraction=fraction, deterministic=deterministic)
            measure_graph = measure.fst
            date_graph = DateFst(cardinal=cardinal, deterministic=deterministic).fst
            punctuation = PunctuationFst(deterministic=True)
            punct_graph = punctuation.graph
            word_graph = WordFst(punctuation=punctuation, deterministic=deterministic).graph
            time_graph = TimeFst(cardinal=cardinal, deterministic=deterministic).fst
            telephone_graph = TelephoneFst(deterministic=deterministic).fst
            electronic_graph = ElectronicFst(deterministic=deterministic).fst
            money_graph = MoneyFst(cardinal=cardinal, decimal=decimal, deterministic=deterministic).fst
            whitelist = WhiteListFst(input_case=input_case, deterministic=deterministic, input_file=whitelist)
            whitelist_graph = whitelist.graph
            serial_graph = SerialFst(cardinal=cardinal, ordinal=deterministic_ordinal, deterministic=deterministic).fst

            # VERBALIZERS
            cardinal = vCardinal(deterministic=deterministic)
            v_cardinal_graph = cardinal.fst
            decimal = vDecimal(cardinal=cardinal, deterministic=deterministic)
            v_decimal_graph = decimal.fst
            ordinal = vOrdinal(deterministic=deterministic)
            v_ordinal_graph = ordinal.fst
            fraction = vFraction(deterministic=deterministic)
            v_fraction_graph = fraction.fst
            v_telephone_graph = vTelephone(deterministic=deterministic).fst
            v_electronic_graph = vElectronic(deterministic=deterministic).fst
            measure = vMeasure(decimal=decimal, cardinal=cardinal, fraction=fraction, deterministic=deterministic)
            v_measure_graph = measure.fst
            v_time_graph = vTime(deterministic=deterministic).fst
            v_date_graph = vDate(ordinal=ordinal, deterministic=deterministic).fst
            v_money_graph = vMoney(decimal=decimal, deterministic=deterministic).fst
            v_roman_graph = vRoman(deterministic=deterministic).fst
            v_abbreviation = vAbbreviation(deterministic=deterministic).fst

            # RangeFst is always fed deterministic time/date readings so that a
            # range such as "2-3pm" has a single inner verbalization.
            det_v_time_graph = vTime(deterministic=True).fst
            det_v_date_graph = vDate(ordinal=vOrdinal(deterministic=True), deterministic=True).fst
            time_final = pynini.compose(time_graph, det_v_time_graph)
            date_final = pynini.compose(date_graph, det_v_date_graph)
            range_graph = RangeFst(
                time=time_final, date=date_final, cardinal=CardinalFst(deterministic=True), deterministic=deterministic
            ).fst
            v_word_graph = vWord(deterministic=deterministic).fst

            # Relative weights: semiotic classes (sem_w) beat punctuation (punct_w);
            # plain-word fallback (word_w) applies only when nothing else matches.
            sem_w = 1
            word_w = 100
            punct_w = 2
            classify_and_verbalize = (
                pynutil.add_weight(whitelist_graph, sem_w)
                | pynutil.add_weight(pynini.compose(time_graph, v_time_graph), sem_w)
                | pynutil.add_weight(pynini.compose(decimal_graph, v_decimal_graph), sem_w)
                | pynutil.add_weight(pynini.compose(measure_graph, v_measure_graph), sem_w)
                | pynutil.add_weight(pynini.compose(cardinal_graph, v_cardinal_graph), sem_w)
                | pynutil.add_weight(pynini.compose(ordinal_graph, v_ordinal_graph), sem_w)
                | pynutil.add_weight(pynini.compose(telephone_graph, v_telephone_graph), sem_w)
                | pynutil.add_weight(pynini.compose(electronic_graph, v_electronic_graph), sem_w)
                | pynutil.add_weight(pynini.compose(fraction_graph, v_fraction_graph), sem_w)
                | pynutil.add_weight(pynini.compose(money_graph, v_money_graph), sem_w)
                | pynutil.add_weight(word_graph, word_w)
                | pynutil.add_weight(pynini.compose(date_graph, v_date_graph), sem_w - 0.01)
                | pynutil.add_weight(pynini.compose(range_graph, v_word_graph), sem_w)
                | pynutil.add_weight(
                    pynini.compose(serial_graph, v_word_graph), 1.1001
                )  # should be higher than the rest of the classes
            ).optimize()

            if not deterministic:
                roman_graph = RomanFst(deterministic=deterministic).fst
                # the weight matches the word_graph weight for "I" cases in long sentences with multiple semiotic tokens
                classify_and_verbalize |= pynutil.add_weight(pynini.compose(roman_graph, v_roman_graph), word_w)

                abbreviation_graph = AbbreviationFst(whitelist=whitelist, deterministic=deterministic).fst
                classify_and_verbalize |= pynutil.add_weight(
                    pynini.compose(abbreviation_graph, v_abbreviation), word_w
                )

            punct_only = pynutil.add_weight(punct_graph, weight=punct_w)
            punct = pynini.closure(
                pynini.compose(pynini.closure(NEMO_WHITE_SPACE, 1), delete_extra_space)
                | (pynutil.insert(" ") + punct_only),
                1,
            )

            # Sentence-level grammar: tokens with optional surrounding punctuation,
            # separated by single spaces.
            token_plus_punct = (
                pynini.closure(punct + pynutil.insert(" "))
                + classify_and_verbalize
                + pynini.closure(pynutil.insert(" ") + punct)
            )

            graph = token_plus_punct + pynini.closure(
                (
                    pynini.compose(pynini.closure(NEMO_WHITE_SPACE, 1), delete_extra_space)
                    | (pynutil.insert(" ") + punct + pynutil.insert(" "))
                )
                + token_plus_punct
            )

            graph |= punct_only + pynini.closure(punct)
            graph = delete_space + graph + delete_space

            # Collapse runs of spaces and trim leading spaces from the output.
            remove_extra_spaces = pynini.closure(NEMO_NOT_SPACE, 1) + pynini.closure(
                delete_extra_space + pynini.closure(NEMO_NOT_SPACE, 1)
            )
            remove_extra_spaces |= (
                pynini.closure(pynutil.delete(" "), 1)
                + pynini.closure(NEMO_NOT_SPACE, 1)
                + pynini.closure(delete_extra_space + pynini.closure(NEMO_NOT_SPACE, 1))
            )

            graph = pynini.compose(graph.optimize(), remove_extra_spaces).optimize()
            self.fst = graph
            # Companion FST that forbids digits in the output (forces full verbalization).
            no_digits = pynini.closure(pynini.difference(NEMO_CHAR, NEMO_DIGIT))
            self.fst_no_digits = pynini.compose(graph, no_digits).optimize()

            if far_file:
                generator_main(far_file, {"tokenize_and_classify": self.fst})
                logging.info(f'ClassifyFst grammars are saved to {far_file}.')
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    NEMO_CHAR,
    NEMO_NOT_SPACE,
    NEMO_SIGMA,
    NEMO_UPPER,
    SINGULAR_TO_PLURAL,
    GraphFst,
    convert_space,
)
from nemo_text_processing.text_normalization.en.taggers.roman import get_names
from nemo_text_processing.text_normalization.en.utils import (
    augment_labels_with_punct_at_end,
    get_abs_path,
    load_labels,
)
from pynini.lib import pynutil


class WhiteListFst(GraphFst):
    """
    Finite state transducer for classifying whitelist, e.g.
        misses -> tokens { name: "mrs" }
    for non-deterministic case: "Dr. Abc" ->
        tokens { name: "drive" } tokens { name: "Abc" }
        tokens { name: "doctor" } tokens { name: "Abc" }
        tokens { name: "Dr." } tokens { name: "Abc" }
    This class has highest priority among all classifier grammars. Whitelisted tokens are defined and loaded from "data/whitelist.tsv".

    Args:
        input_case: accepting either "lower_cased" or "cased" input.
        deterministic: if True will provide a single transduction option,
            for False multiple options (used for audio-based normalization)
        input_file: path to a file with whitelist replacements
    """

    def __init__(self, input_case: str, deterministic: bool = True, input_file: str = None):
        super().__init__(name="whitelist", kind="classify", deterministic=deterministic)

        def _get_whitelist_graph(input_case, file, keep_punct_add_end: bool = False):
            # Build a string-map FST from a two-column TSV; optionally also accept
            # keys with their trailing punctuation preserved.
            whitelist = load_labels(file)
            if input_case == "lower_cased":
                whitelist = [[x.lower(), y] for x, y in whitelist]
            else:
                whitelist = [[x, y] for x, y in whitelist]

            if keep_punct_add_end:
                whitelist.extend(augment_labels_with_punct_at_end(whitelist))

            graph = pynini.string_map(whitelist)
            return graph

        graph = _get_whitelist_graph(input_case, get_abs_path("data/whitelist/tts.tsv"))
        graph |= _get_whitelist_graph(input_case, get_abs_path("data/whitelist/UK_to_US.tsv"))  # Jiayu 2022.10
        # Symbol replacements apply only to tokens that contain no "/" —
        # presumably to avoid clashing with date/fraction slashes.
        graph |= pynini.compose(
            pynini.difference(NEMO_SIGMA, pynini.accep("/")).optimize(),
            _get_whitelist_graph(input_case, get_abs_path("data/whitelist/symbol.tsv")),
        ).optimize()

        if deterministic:
            names = get_names()
            # "St"/"st" before a known name expands to "Saint" (optional trailing period).
            graph |= (
                pynini.cross(pynini.union("st", "St", "ST"), "Saint")
                + pynini.closure(pynutil.delete("."))
                + pynini.accep(" ")
                + names
            )
        else:
            graph |= _get_whitelist_graph(
                input_case, get_abs_path("data/whitelist/alternatives.tsv"), keep_punct_add_end=True
            )

        # Accept dotted all-caps abbreviations ("U.S.A." -> "USA"-style) with or
        # without a space after each period.
        for x in [".", ". "]:
            graph |= (
                NEMO_UPPER
                + pynini.closure(pynutil.delete(x) + NEMO_UPPER, 2)
                + pynini.closure(pynutil.delete("."), 0, 1)
            )

        if not deterministic:
            multiple_forms_whitelist_graph = get_formats(get_abs_path("data/whitelist/alternatives_all_format.tsv"))
            graph |= multiple_forms_whitelist_graph

            # Measurement-unit names (singular and plural) of 3+ characters.
            graph_unit = pynini.string_file(get_abs_path("data/measure/unit.tsv")) | pynini.string_file(
                get_abs_path("data/measure/unit_alternatives.tsv")
            )
            graph_unit_plural = graph_unit @ SINGULAR_TO_PLURAL
            units_graph = pynini.compose(NEMO_CHAR ** (3, ...), convert_space(graph_unit | graph_unit_plural))
            graph |= units_graph

        # convert to states only if comma is present before the abbreviation to avoid converting all caps words,
        # e.g. "IN", "OH", "OK"
        # TODO or only exclude above?
        states = load_labels(get_abs_path("data/address/state.tsv"))
        additional_options = []
        for x, y in states:
            if input_case == "lower_cased":
                x = x.lower()
            additional_options.append((x, f"{y[0]}.{y[1:]}"))
            if not deterministic:
                additional_options.append((x, f"{y[0]}.{y[1:]}."))

        states.extend(additional_options)
        state_graph = pynini.string_map(states)
        graph |= pynini.closure(NEMO_NOT_SPACE, 1) + pynini.union(", ", ",") + pynini.invert(state_graph).optimize()

        if input_file:
            # A user-supplied whitelist replaces the defaults in deterministic
            # mode, but augments them in non-deterministic mode.
            whitelist_provided = _get_whitelist_graph(input_case, input_file)
            if not deterministic:
                graph |= whitelist_provided
            else:
                graph = whitelist_provided

        self.graph = (convert_space(graph)).optimize()

        self.fst = (pynutil.insert("name: \"") + self.graph + pynutil.insert("\"")).optimize()


def get_formats(input_f, input_case="cased", is_default=True):
    """
    Adds various abbreviation format options to the list of acceptable input forms
    """
    multiple_formats = load_labels(input_f)
    additional_options = []
    for x, y in multiple_formats:
        if input_case == "lower_cased":
            x = x.lower()
        additional_options.append((f"{x}.", y))  # default "dr" -> doctor, this includes period "dr." -> doctor
        additional_options.append((f"{x[0].upper() + x[1:]}", f"{y[0].upper() + y[1:]}"))  # "Dr" -> Doctor
        additional_options.append((f"{x[0].upper() + x[1:]}.", f"{y[0].upper() + y[1:]}"))  # "Dr." -> Doctor
    multiple_formats.extend(additional_options)

    if not is_default:
        # Tag raw/normalized spans so downstream consumers can recover the original text.
        multiple_formats = [(x, f"|raw_start|{x}|raw_end||norm_start|{y}|norm_end|") for (x, y) in multiple_formats]

    multiple_formats = pynini.string_map(multiple_formats)
    return multiple_formats


# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    MIN_NEG_WEIGHT,
    NEMO_ALPHA,
    NEMO_DIGIT,
    NEMO_NOT_SPACE,
    NEMO_SIGMA,
    GraphFst,
    convert_space,
    get_abs_path,
)
from nemo_text_processing.text_normalization.en.taggers.punctuation import PunctuationFst
from pynini.examples import plurals
from pynini.lib import pynutil


class WordFst(GraphFst):
    """
    Finite state transducer for classifying word. Considers sentence boundary exceptions.
        e.g. sleep -> tokens { name: "sleep" }

    Args:
        punctuation: PunctuationFst
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, punctuation: GraphFst, deterministic: bool = True):
        super().__init__(name="word", kind="classify", deterministic=deterministic)

        punct = PunctuationFst().graph
        default_graph = pynini.closure(pynini.difference(NEMO_NOT_SPACE, punct.project("input")), 1)
        # Words containing currency symbols or digits are penalized so that the
        # semiotic classifiers get first pick at them.
        symbols_to_exclude = (pynini.union("$", "€", "₩", "£", "¥", "#", "%") | NEMO_DIGIT).optimize()
        graph = pynini.closure(pynini.difference(NEMO_NOT_SPACE, symbols_to_exclude), 1)
        graph = pynutil.add_weight(graph, MIN_NEG_WEIGHT) | default_graph

        # leave phones of format [HH AH0 L OW1] untouched
        phoneme_unit = pynini.closure(NEMO_ALPHA, 1) + pynini.closure(NEMO_DIGIT)
        phoneme = (
            pynini.accep(pynini.escape("["))
            + pynini.closure(phoneme_unit + pynini.accep(" "))
            + phoneme_unit
            + pynini.accep(pynini.escape("]"))
        )

        # leave IPA phones of format [ˈdoʊv] untouched, single words and sentences with punctuation marks allowed
        punct_marks = pynini.union(*punctuation.punct_marks).optimize()
        stress = pynini.union("ˈ", "'", "ˌ")
        ipa_phoneme_unit = pynini.string_file(get_abs_path("data/whitelist/ipa_symbols.tsv"))
        # word in ipa form
        ipa_phonemes = (
            pynini.closure(stress, 0, 1)
            + pynini.closure(ipa_phoneme_unit, 1)
            + pynini.closure(stress | ipa_phoneme_unit)
        )
        # allow sentences of words in IPA format separated with spaces or punct marks
        delim = (punct_marks | pynini.accep(" ")) ** (1, ...)
        ipa_phonemes = ipa_phonemes + pynini.closure(delim + ipa_phonemes) + pynini.closure(delim, 0, 1)
        ipa_phonemes = (pynini.accep(pynini.escape("[")) + ipa_phonemes + pynini.accep(pynini.escape("]"))).optimize()

        if not deterministic:
            # Relaxed bracketed forms: optional spaces just inside the brackets.
            phoneme = (
                pynini.accep(pynini.escape("["))
                + pynini.closure(pynini.accep(" "), 0, 1)
                + pynini.closure(phoneme_unit + pynini.accep(" "))
                + phoneme_unit
                + pynini.closure(pynini.accep(" "), 0, 1)
                + pynini.accep(pynini.escape("]"))
            ).optimize()
            ipa_phonemes = (
                pynini.accep(pynini.escape("[")) + ipa_phonemes + pynini.accep(pynini.escape("]"))
            ).optimize()

        phoneme |= ipa_phonemes
        # Bracketed phoneme strings take priority over the generic word graph.
        self.graph = plurals._priority_union(convert_space(phoneme.optimize()), graph, NEMO_SIGMA)
        self.fst = (pynutil.insert("name: \"") + self.graph + pynutil.insert("\"")).optimize()


# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import os


def get_abs_path(rel_path):
    """
    Get absolute path

    Args:
        rel_path: relative path to this file

    Returns absolute path
    """
    # Resolve relative to this module's directory, not the process CWD.
    return os.path.dirname(os.path.abspath(__file__)) + '/' + rel_path


def load_labels(abs_path):
    """
    loads relative path file as dictionary

    Args:
        abs_path: absolute path

    Returns dictionary of mappings
    """
    # Fix: the original opened the file without ever closing it (leaked handle);
    # a context manager guarantees the file is closed even on a parse error.
    with open(abs_path, encoding="utf-8") as label_tsv:
        labels = list(csv.reader(label_tsv, delimiter="\t"))
    return labels


def augment_labels_with_punct_at_end(labels):
    """
    augments labels: if key ends on a punctuation that value does not have, add a new label
    where the value maintains the punctuation

    Args:
        labels : input labels
    Returns:
        additional labels
    """
    res = []
    for label in labels:
        # Only rows with both a key and a value can be augmented.
        if len(label) > 1:
            if label[0][-1] == "." and label[1][-1] != ".":
                # Mirror the key's trailing period onto the value; keep any extra columns.
                res.append([label[0], label[1] + "."] + label[2:])
    return res


# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, GraphFst
from pynini.lib import pynutil


class AbbreviationFst(GraphFst):
    """
    Finite state transducer for verbalizing abbreviations
        e.g. tokens { abbreviation { value: "A B C" } } -> "ABC"

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="abbreviation", kind="verbalize", deterministic=deterministic)

        # Strip the protobuf-style 'value: "..."' wrapper, keeping the quoted content.
        graph = pynutil.delete("value: \"") + pynini.closure(NEMO_NOT_QUOTE, 1) + pynutil.delete("\"")
        delete_tokens = self.delete_tokens(graph)
        self.fst = delete_tokens.optimize()


# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, GraphFst, delete_space
from pynini.lib import pynutil


class CardinalFst(GraphFst):
    """
    Finite state transducer for verbalizing cardinal, e.g.
        cardinal { negative: "true" integer: "23" } -> minus twenty three

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple options (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="cardinal", kind="verbalize", deterministic=deterministic)

        # 'negative: "true"' verbalizes as a leading "minus " (also "negative "
        # as an alternative in non-deterministic mode); the field is optional.
        self.optional_sign = pynini.cross("negative: \"true\"", "minus ")
        if not deterministic:
            self.optional_sign |= pynini.cross("negative: \"true\"", "negative ")
        self.optional_sign = pynini.closure(self.optional_sign + delete_space, 0, 1)

        integer = pynini.closure(NEMO_NOT_QUOTE)

        # Strip the quotes around the already-verbalized integer string.
        self.integer = delete_space + pynutil.delete("\"") + integer + pynutil.delete("\"")
        integer = pynutil.delete("integer:") + self.integer

        self.numbers = self.optional_sign + integer
        delete_tokens = self.delete_tokens(self.numbers)
        self.fst = delete_tokens.optimize()
+ +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import ( + NEMO_NOT_QUOTE, + NEMO_SIGMA, + GraphFst, + delete_extra_space, + delete_space, +) +from pynini.examples import plurals +from pynini.lib import pynutil + + +class DateFst(GraphFst): + """ + Finite state transducer for verbalizing date, e.g. + date { month: "february" day: "five" year: "twenty twelve" preserve_order: true } -> february fifth twenty twelve + date { day: "five" month: "february" year: "twenty twelve" preserve_order: true } -> the fifth of february twenty twelve + + Args: + ordinal: OrdinalFst + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + """ + + def __init__(self, ordinal: GraphFst, deterministic: bool = True, lm: bool = False): + super().__init__(name="date", kind="verbalize", deterministic=deterministic) + + month = pynini.closure(NEMO_NOT_QUOTE, 1) + day_cardinal = ( + pynutil.delete("day:") + + delete_space + + pynutil.delete("\"") + + pynini.closure(NEMO_NOT_QUOTE, 1) + + pynutil.delete("\"") + ) + day = day_cardinal @ ordinal.suffix + + month = pynutil.delete("month:") + delete_space + pynutil.delete("\"") + month + pynutil.delete("\"") + + year = ( + pynutil.delete("year:") + + delete_space + + pynutil.delete("\"") + + pynini.closure(NEMO_NOT_QUOTE, 1) + + delete_space + + pynutil.delete("\"") + ) + + # month (day) year + graph_mdy = ( + month + pynini.closure(delete_extra_space + day, 0, 1) + pynini.closure(delete_extra_space + year, 0, 1) + ) + # may 5 -> may five + if not deterministic and not lm: + graph_mdy |= ( + month + + pynini.closure(delete_extra_space + day_cardinal, 0, 1) + + pynini.closure(delete_extra_space + year, 0, 1) + ) + + # day month year + graph_dmy = ( + pynutil.insert("the ") + + day + + delete_extra_space + + pynutil.insert("of ") + + month + + pynini.closure(delete_extra_space + year, 0, 1) + ) + + optional_preserve_order = 
pynini.closure( + pynutil.delete("preserve_order:") + delete_space + pynutil.delete("true") + delete_space + | pynutil.delete("field_order:") + + delete_space + + pynutil.delete("\"") + + NEMO_NOT_QUOTE + + pynutil.delete("\"") + + delete_space + ) + + final_graph = ( + (plurals._priority_union(graph_mdy, pynutil.add_weight(graph_dmy, 0.0001), NEMO_SIGMA) | year) + + delete_space + + optional_preserve_order + ) + delete_tokens = self.delete_tokens(final_graph) + self.fst = delete_tokens.optimize() diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/decimal.py b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/decimal.py new file mode 100644 index 0000000..787bcea --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/decimal.py @@ -0,0 +1,67 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, GraphFst, delete_space, insert_space +from pynini.lib import pynutil + + +class DecimalFst(GraphFst): + """ + Finite state transducer for verbalizing decimal, e.g. 
+ decimal { negative: "true" integer_part: "twelve" fractional_part: "five o o six" quantity: "billion" } -> minus twelve point five o o six billion + + Args: + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + """ + + def __init__(self, cardinal, deterministic: bool = True): + super().__init__(name="decimal", kind="verbalize", deterministic=deterministic) + self.optional_sign = pynini.cross("negative: \"true\"", "minus ") + if not deterministic: + self.optional_sign |= pynini.cross("negative: \"true\"", "negative ") + self.optional_sign = pynini.closure(self.optional_sign + delete_space, 0, 1) + self.integer = pynutil.delete("integer_part:") + cardinal.integer + self.optional_integer = pynini.closure(self.integer + delete_space + insert_space, 0, 1) + self.fractional_default = ( + pynutil.delete("fractional_part:") + + delete_space + + pynutil.delete("\"") + + pynini.closure(NEMO_NOT_QUOTE, 1) + + pynutil.delete("\"") + ) + + self.fractional = pynutil.insert("point ") + self.fractional_default + + self.quantity = ( + delete_space + + insert_space + + pynutil.delete("quantity:") + + delete_space + + pynutil.delete("\"") + + pynini.closure(NEMO_NOT_QUOTE, 1) + + pynutil.delete("\"") + ) + self.optional_quantity = pynini.closure(self.quantity, 0, 1) + + graph = self.optional_sign + ( + self.integer + | (self.integer + self.quantity) + | (self.optional_integer + self.fractional + self.optional_quantity) + ) + + self.numbers = graph + delete_tokens = self.delete_tokens(graph) + self.fst = delete_tokens.optimize() diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/electronic.py b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/electronic.py new file mode 100644 index 0000000..884f125 --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/electronic.py @@ -0,0 +1,97 @@ +# Copyright 
(c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import ( + NEMO_NOT_QUOTE, + NEMO_NOT_SPACE, + NEMO_SIGMA, + TO_UPPER, + GraphFst, + delete_extra_space, + delete_space, + insert_space, +) +from nemo_text_processing.text_normalization.en.utils import get_abs_path +from pynini.examples import plurals +from pynini.lib import pynutil + + +class ElectronicFst(GraphFst): + """ + Finite state transducer for verbalizing electronic + e.g. 
tokens { electronic { username: "cdf1" domain: "abc.edu" } } -> c d f one at a b c dot e d u + + Args: + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + """ + + def __init__(self, deterministic: bool = True): + super().__init__(name="electronic", kind="verbalize", deterministic=deterministic) + graph_digit_no_zero = pynini.invert(pynini.string_file(get_abs_path("data/number/digit.tsv"))).optimize() + graph_zero = pynini.cross("0", "zero") + + if not deterministic: + graph_zero |= pynini.cross("0", "o") | pynini.cross("0", "oh") + + graph_digit = graph_digit_no_zero | graph_zero + graph_symbols = pynini.string_file(get_abs_path("data/electronic/symbol.tsv")).optimize() + + default_chars_symbols = pynini.cdrewrite( + pynutil.insert(" ") + (graph_symbols | graph_digit) + pynutil.insert(" "), "", "", NEMO_SIGMA + ) + default_chars_symbols = pynini.compose( + pynini.closure(NEMO_NOT_SPACE), default_chars_symbols.optimize() + ).optimize() + + user_name = ( + pynutil.delete("username:") + + delete_space + + pynutil.delete("\"") + + default_chars_symbols + + pynutil.delete("\"") + ) + + domain_common = pynini.string_file(get_abs_path("data/electronic/domain.tsv")) + + domain = ( + default_chars_symbols + + insert_space + + plurals._priority_union( + domain_common, pynutil.add_weight(pynini.cross(".", "dot"), weight=0.0001), NEMO_SIGMA + ) + + pynini.closure( + insert_space + (pynini.cdrewrite(TO_UPPER, "", "", NEMO_SIGMA) @ default_chars_symbols), 0, 1 + ) + ) + domain = ( + pynutil.delete("domain:") + + delete_space + + pynutil.delete("\"") + + domain + + delete_space + + pynutil.delete("\"") + ).optimize() + + protocol = pynutil.delete("protocol: \"") + pynini.closure(NEMO_NOT_QUOTE, 1) + pynutil.delete("\"") + graph = ( + pynini.closure(protocol + delete_space, 0, 1) + + pynini.closure(user_name + delete_space + pynutil.insert(" at ") + delete_space, 0, 1) + + domain 
+ + delete_space + ).optimize() @ pynini.cdrewrite(delete_extra_space, "", "", NEMO_SIGMA) + + delete_tokens = self.delete_tokens(graph) + self.fst = delete_tokens.optimize() diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/fraction.py b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/fraction.py new file mode 100644 index 0000000..d0c5dc2 --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/fraction.py @@ -0,0 +1,88 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, NEMO_SIGMA, GraphFst, insert_space +from nemo_text_processing.text_normalization.en.verbalizers.ordinal import OrdinalFst +from pynini.examples import plurals +from pynini.lib import pynutil + + +class FractionFst(GraphFst): + """ + Finite state transducer for verbalizing fraction + e.g. 
tokens { fraction { integer: "twenty three" numerator: "four" denominator: "five" } } -> + twenty three and four fifth + + Args: + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + """ + + def __init__(self, deterministic: bool = True, lm: bool = False): + super().__init__(name="fraction", kind="verbalize", deterministic=deterministic) + suffix = OrdinalFst().suffix + + integer = pynutil.delete("integer_part: \"") + pynini.closure(NEMO_NOT_QUOTE) + pynutil.delete("\" ") + denominator_one = pynini.cross("denominator: \"one\"", "over one") + denominator_half = pynini.cross("denominator: \"two\"", "half") + denominator_quarter = pynini.cross("denominator: \"four\"", "quarter") + + denominator_rest = ( + pynutil.delete("denominator: \"") + pynini.closure(NEMO_NOT_QUOTE) @ suffix + pynutil.delete("\"") + ) + + denominators = plurals._priority_union( + denominator_one, + plurals._priority_union( + denominator_half, + plurals._priority_union(denominator_quarter, denominator_rest, NEMO_SIGMA), + NEMO_SIGMA, + ), + NEMO_SIGMA, + ).optimize() + if not deterministic: + denominators |= pynutil.delete("denominator: \"") + (pynini.accep("four") @ suffix) + pynutil.delete("\"") + + numerator_one = pynutil.delete("numerator: \"") + pynini.accep("one") + pynutil.delete("\" ") + numerator_one = numerator_one + insert_space + denominators + numerator_rest = ( + pynutil.delete("numerator: \"") + + (pynini.closure(NEMO_NOT_QUOTE) - pynini.accep("one")) + + pynutil.delete("\" ") + ) + numerator_rest = numerator_rest + insert_space + denominators + numerator_rest @= pynini.cdrewrite( + plurals._priority_union(pynini.cross("half", "halves"), pynutil.insert("s"), NEMO_SIGMA), + "", + "[EOS]", + NEMO_SIGMA, + ) + + graph = numerator_one | numerator_rest + + conjunction = pynutil.insert("and ") + if not deterministic and not lm: + conjunction = pynini.closure(conjunction, 0, 1) + + integer 
= pynini.closure(integer + insert_space + conjunction, 0, 1) + + graph = integer + graph + graph @= pynini.cdrewrite( + pynini.cross("and one half", "and a half") | pynini.cross("over ones", "over one"), "", "[EOS]", NEMO_SIGMA + ) + + self.graph = graph + delete_tokens = self.delete_tokens(self.graph) + self.fst = delete_tokens.optimize() diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/measure.py b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/measure.py new file mode 100644 index 0000000..e4a23b3 --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/measure.py @@ -0,0 +1,102 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, GraphFst, delete_space, insert_space +from pynini.lib import pynutil + + +class MeasureFst(GraphFst): + """ + Finite state transducer for verbalizing measure, e.g. 
+ measure { negative: "true" cardinal { integer: "twelve" } units: "kilograms" } -> minus twelve kilograms + measure { decimal { integer_part: "twelve" fractional_part: "five" } units: "kilograms" } -> twelve point five kilograms + tokens { measure { units: "covid" decimal { integer_part: "nineteen" fractional_part: "five" } } } -> covid nineteen point five + + Args: + decimal: DecimalFst + cardinal: CardinalFst + fraction: FractionFst + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + """ + + def __init__(self, decimal: GraphFst, cardinal: GraphFst, fraction: GraphFst, deterministic: bool = True): + super().__init__(name="measure", kind="verbalize", deterministic=deterministic) + optional_sign = cardinal.optional_sign + unit = ( + pynutil.delete("units: \"") + + pynini.difference(pynini.closure(NEMO_NOT_QUOTE, 1), pynini.union("address", "math")) + + pynutil.delete("\"") + + delete_space + ) + + if not deterministic: + unit |= pynini.compose(unit, pynini.cross(pynini.union("inch", "inches"), "\"")) + + graph_decimal = ( + pynutil.delete("decimal {") + + delete_space + + optional_sign + + delete_space + + decimal.numbers + + delete_space + + pynutil.delete("}") + ) + graph_cardinal = ( + pynutil.delete("cardinal {") + + delete_space + + optional_sign + + delete_space + + cardinal.numbers + + delete_space + + pynutil.delete("}") + ) + + graph_fraction = ( + pynutil.delete("fraction {") + delete_space + fraction.graph + delete_space + pynutil.delete("}") + ) + + graph = (graph_cardinal | graph_decimal | graph_fraction) + delete_space + insert_space + unit + + # SH adds "preserve_order: true" by default + preserve_order = pynutil.delete("preserve_order:") + delete_space + pynutil.delete("true") + delete_space + graph |= unit + insert_space + (graph_cardinal | graph_decimal) + delete_space + pynini.closure(preserve_order) + # for only unit + graph |= ( + 
pynutil.delete("cardinal { integer: \"-\"") + + delete_space + + pynutil.delete("}") + + delete_space + + unit + + pynini.closure(preserve_order) + ) + address = ( + pynutil.delete("units: \"address\" ") + + delete_space + + graph_cardinal + + delete_space + + pynini.closure(preserve_order) + ) + math = ( + pynutil.delete("units: \"math\" ") + + delete_space + + graph_cardinal + + delete_space + + pynini.closure(preserve_order) + ) + graph |= address | math + + delete_tokens = self.delete_tokens(graph) + self.fst = delete_tokens.optimize() diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/money.py b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/money.py new file mode 100644 index 0000000..b3cbc4a --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/money.py @@ -0,0 +1,71 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import ( + NEMO_NOT_QUOTE, + GraphFst, + delete_extra_space, + delete_preserve_order, +) +from pynini.lib import pynutil + + +class MoneyFst(GraphFst): + """ + Finite state transducer for verbalizing money, e.g. 
+ money { integer_part: "twelve" fractional_part: "o five" currency: "dollars" } -> twelve o five dollars + + Args: + decimal: DecimalFst + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + """ + + def __init__(self, decimal: GraphFst, deterministic: bool = True): + super().__init__(name="money", kind="verbalize", deterministic=deterministic) + keep_space = pynini.accep(" ") + maj = pynutil.delete("currency_maj: \"") + pynini.closure(NEMO_NOT_QUOTE, 1) + pynutil.delete("\"") + min = pynutil.delete("currency_min: \"") + pynini.closure(NEMO_NOT_QUOTE, 1) + pynutil.delete("\"") + + fractional_part = ( + pynutil.delete("fractional_part: \"") + pynini.closure(NEMO_NOT_QUOTE, 1) + pynutil.delete("\"") + ) + + integer_part = decimal.integer + + # *** currency_maj + graph_integer = integer_part + keep_space + maj + + # *** currency_maj + (***) | ((and) *** current_min) + fractional = fractional_part + delete_extra_space + min + + if not deterministic: + fractional |= pynutil.insert("and ") + fractional + + graph_integer_with_minor = integer_part + keep_space + maj + keep_space + fractional + delete_preserve_order + + # *** point *** currency_maj + graph_decimal = decimal.numbers + keep_space + maj + + # *** current_min + graph_minor = fractional_part + delete_extra_space + min + delete_preserve_order + + graph = graph_integer | graph_integer_with_minor | graph_decimal | graph_minor + + if not deterministic: + graph |= graph_integer + delete_preserve_order + + delete_tokens = self.delete_tokens(graph) + self.fst = delete_tokens.optimize() diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/ordinal.py b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/ordinal.py new file mode 100644 index 0000000..c64579a --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/ordinal.py @@ -0,0 +1,53 @@ 
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, NEMO_SIGMA, GraphFst, delete_space +from nemo_text_processing.text_normalization.en.utils import get_abs_path +from pynini.lib import pynutil + + +class OrdinalFst(GraphFst): + """ + Finite state transducer for verbalizing ordinal, e.g. + ordinal { integer: "thirteen" } } -> thirteenth + + Args: + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + """ + + def __init__(self, deterministic: bool = True): + super().__init__(name="ordinal", kind="verbalize", deterministic=deterministic) + + graph_digit = pynini.string_file(get_abs_path("data/ordinal/digit.tsv")).invert() + graph_teens = pynini.string_file(get_abs_path("data/ordinal/teen.tsv")).invert() + + graph = ( + pynutil.delete("integer:") + + delete_space + + pynutil.delete("\"") + + pynini.closure(NEMO_NOT_QUOTE, 1) + + pynutil.delete("\"") + ) + convert_rest = pynutil.insert("th") + + suffix = pynini.cdrewrite( + graph_digit | graph_teens | pynini.cross("ty", "tieth") | convert_rest, "", "[EOS]", NEMO_SIGMA, + ).optimize() + self.graph = pynini.compose(graph, suffix) + self.suffix = suffix + delete_tokens = self.delete_tokens(self.graph) + self.fst = delete_tokens.optimize() diff --git 
a/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/post_processing.py b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/post_processing.py new file mode 100644 index 0000000..6c87da1 --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/post_processing.py @@ -0,0 +1,180 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import ( + MIN_NEG_WEIGHT, + NEMO_ALPHA, + NEMO_CHAR, + NEMO_SIGMA, + NEMO_SPACE, + generator_main, +) +from nemo_text_processing.text_normalization.en.taggers.punctuation import PunctuationFst +from pynini.lib import pynutil + + + +class PostProcessingFst: + """ + Finite state transducer that post-processing an entire sentence after verbalization is complete, e.g. + removes extra spaces around punctuation marks " ( one hundred and twenty three ) " -> "(one hundred and twenty three)" + + Args: + cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache. 
+ overwrite_cache: set to True to overwrite .far files + """ + + def __init__(self, cache_dir: str = None, overwrite_cache: bool = False): + + far_file = None + if cache_dir is not None and cache_dir != "None": + os.makedirs(cache_dir, exist_ok=True) + far_file = os.path.join(cache_dir, "en_tn_post_processing.far") + if not overwrite_cache and far_file and os.path.exists(far_file): + self.fst = pynini.Far(far_file, mode="r")["post_process_graph"] + else: + self.set_punct_dict() + self.fst = self.get_punct_postprocess_graph() + + if far_file: + generator_main(far_file, {"post_process_graph": self.fst}) + + def set_punct_dict(self): + self.punct_marks = { + "'": [ + "'", + '´', + 'ʹ', + 'ʻ', + 'ʼ', + 'ʽ', + 'ʾ', + 'ˈ', + 'ˊ', + 'ˋ', + '˴', + 'ʹ', + '΄', + '՚', + '՝', + 'י', + '׳', + 'ߴ', + 'ߵ', + 'ᑊ', + 'ᛌ', + '᾽', + '᾿', + '`', + '´', + '῾', + '‘', + '’', + '‛', + '′', + '‵', + 'ꞌ', + ''', + '`', + '𖽑', + '𖽒', + ], + } + + def get_punct_postprocess_graph(self): + """ + Returns graph to post process punctuation marks. + + {``} quotes are converted to {"}. Note, if there are spaces around single quote {'}, they will be kept. + By default, a space is added after a punctuation mark, and spaces are removed before punctuation marks. 
+ """ + punct_marks_all = PunctuationFst().punct_marks + + # no_space_before_punct assume no space before them + quotes = ["'", "\"", "``", "«"] + dashes = ["-", "—"] + brackets = ["<", "{", "("] + open_close_single_quotes = [ + ("`", "`"), + ] + + open_close_double_quotes = [('"', '"'), ("``", "``"), ("“", "”")] + open_close_symbols = open_close_single_quotes + open_close_double_quotes + allow_space_before_punct = ["&"] + quotes + dashes + brackets + [k[0] for k in open_close_symbols] + + no_space_before_punct = [m for m in punct_marks_all if m not in allow_space_before_punct] + no_space_before_punct = pynini.union(*no_space_before_punct) + no_space_after_punct = pynini.union(*brackets) + delete_space = pynutil.delete(" ") + delete_space_optional = pynini.closure(delete_space, 0, 1) + + # non_punct allows space + # delete space before no_space_before_punct marks, if present + non_punct = pynini.difference(NEMO_CHAR, no_space_before_punct).optimize() + graph = ( + pynini.closure(non_punct) + + pynini.closure( + no_space_before_punct | pynutil.add_weight(delete_space + no_space_before_punct, MIN_NEG_WEIGHT) + ) + + pynini.closure(non_punct) + ) + graph = pynini.closure(graph).optimize() + graph = pynini.compose( + graph, pynini.cdrewrite(pynini.cross("``", '"'), "", "", NEMO_SIGMA).optimize() + ).optimize() + + # remove space after no_space_after_punct (even if there are no matching closing brackets) + no_space_after_punct = pynini.cdrewrite(delete_space, no_space_after_punct, NEMO_SIGMA, NEMO_SIGMA).optimize() + graph = pynini.compose(graph, no_space_after_punct).optimize() + + # remove space around text in quotes + single_quote = pynutil.add_weight(pynini.accep("`"), MIN_NEG_WEIGHT) + double_quotes = pynutil.add_weight(pynini.accep('"'), MIN_NEG_WEIGHT) + quotes_graph = ( + single_quote + delete_space_optional + NEMO_ALPHA + NEMO_SIGMA + delete_space_optional + single_quote + ).optimize() + + # this is to make sure multiple quotes are tagged from right to left 
without skipping any quotes in the left + not_alpha = pynini.difference(NEMO_CHAR, NEMO_ALPHA).optimize() | pynutil.add_weight( + NEMO_SPACE, MIN_NEG_WEIGHT + ) + end = pynini.closure(pynutil.add_weight(not_alpha, MIN_NEG_WEIGHT)) + quotes_graph |= ( + double_quotes + + delete_space_optional + + NEMO_ALPHA + + NEMO_SIGMA + + delete_space_optional + + double_quotes + + end + ) + + quotes_graph = pynutil.add_weight(quotes_graph, MIN_NEG_WEIGHT) + quotes_graph = NEMO_SIGMA + pynini.closure(NEMO_SIGMA + quotes_graph + NEMO_SIGMA) + + graph = pynini.compose(graph, quotes_graph).optimize() + + # remove space between a word and a single quote followed by s + remove_space_around_single_quote = pynini.cdrewrite( + delete_space_optional + pynini.union(*self.punct_marks["'"]) + delete_space, + NEMO_ALPHA, + pynini.union("s ", "s[EOS]"), + NEMO_SIGMA, + ) + + graph = pynini.compose(graph, remove_space_around_single_quote).optimize() + return graph diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/roman.py b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/roman.py new file mode 100644 index 0000000..43faebe --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/roman.py @@ -0,0 +1,68 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, GraphFst +from nemo_text_processing.text_normalization.en.verbalizers.ordinal import OrdinalFst +from pynini.lib import pynutil + + +class RomanFst(GraphFst): + """ + Finite state transducer for verbalizing roman numerals + e.g. tokens { roman { integer: "one" } } -> one + + Args: + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + """ + + def __init__(self, deterministic: bool = True): + super().__init__(name="roman", kind="verbalize", deterministic=deterministic) + suffix = OrdinalFst().suffix + + cardinal = pynini.closure(NEMO_NOT_QUOTE) + ordinal = pynini.compose(cardinal, suffix) + + graph = ( + pynutil.delete("key_cardinal: \"") + + pynini.closure(NEMO_NOT_QUOTE, 1) + + pynutil.delete("\"") + + pynini.accep(" ") + + pynutil.delete("integer: \"") + + cardinal + + pynutil.delete("\"") + ).optimize() + + graph |= ( + pynutil.delete("default_cardinal: \"default\" integer: \"") + cardinal + pynutil.delete("\"") + ).optimize() + + graph |= ( + pynutil.delete("default_ordinal: \"default\" integer: \"") + ordinal + pynutil.delete("\"") + ).optimize() + + graph |= ( + pynutil.delete("key_the_ordinal: \"") + + pynini.closure(NEMO_NOT_QUOTE, 1) + + pynutil.delete("\"") + + pynini.accep(" ") + + pynutil.delete("integer: \"") + + pynini.closure(pynutil.insert("the "), 0, 1) + + ordinal + + pynutil.delete("\"") + ).optimize() + + delete_tokens = self.delete_tokens(graph) + self.fst = delete_tokens.optimize() diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/telephone.py b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/telephone.py new file mode 100644 index 0000000..4af7bbb --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/telephone.py @@ -0,0 +1,63 @@ +# Copyright (c) 2021, 
NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, GraphFst, delete_space, insert_space +from pynini.lib import pynutil + + +class TelephoneFst(GraphFst): + """ + Finite state transducer for verbalizing telephone numbers, e.g. + telephone { country_code: "one" number_part: "one two three, one two three, five six seven eight" extension: "one" } + -> one, one two three, one two three, five six seven eight, one + + Args: + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + """ + + def __init__(self, deterministic: bool = True): + super().__init__(name="telephone", kind="verbalize", deterministic=deterministic) + + optional_country_code = pynini.closure( + pynutil.delete("country_code: \"") + + pynini.closure(NEMO_NOT_QUOTE, 1) + + pynutil.delete("\"") + + delete_space + + insert_space, + 0, + 1, + ) + + number_part = ( + pynutil.delete("number_part: \"") + + pynini.closure(NEMO_NOT_QUOTE, 1) + + pynini.closure(pynutil.add_weight(pynutil.delete(" "), -0.0001), 0, 1) + + pynutil.delete("\"") + ) + + optional_extension = pynini.closure( + delete_space + + insert_space + + pynutil.delete("extension: \"") + + pynini.closure(NEMO_NOT_QUOTE, 1) + + pynutil.delete("\""), + 0, + 1, + ) + + graph = optional_country_code + number_part + 
optional_extension + delete_tokens = self.delete_tokens(graph) + self.fst = delete_tokens.optimize() diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/time.py b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/time.py new file mode 100644 index 0000000..518c7df --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/time.py @@ -0,0 +1,102 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import ( + NEMO_NOT_QUOTE, + NEMO_SIGMA, + GraphFst, + delete_space, + insert_space, +) +from pynini.lib import pynutil + + +class TimeFst(GraphFst): + """ + Finite state transducer for verbalizing time, e.g. 
+ time { hours: "twelve" minutes: "thirty" suffix: "a m" zone: "e s t" } -> twelve thirty a m e s t + time { hours: "twelve" } -> twelve o'clock + + Args: + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + """ + + def __init__(self, deterministic: bool = True): + super().__init__(name="time", kind="verbalize", deterministic=deterministic) + hour = ( + pynutil.delete("hours:") + + delete_space + + pynutil.delete("\"") + + pynini.closure(NEMO_NOT_QUOTE, 1) + + pynutil.delete("\"") + ) + minute = ( + pynutil.delete("minutes:") + + delete_space + + pynutil.delete("\"") + + pynini.closure(NEMO_NOT_QUOTE, 1) + + pynutil.delete("\"") + ) + suffix = ( + pynutil.delete("suffix:") + + delete_space + + pynutil.delete("\"") + + pynini.closure(NEMO_NOT_QUOTE, 1) + + pynutil.delete("\"") + ) + optional_suffix = pynini.closure(delete_space + insert_space + suffix, 0, 1) + zone = ( + pynutil.delete("zone:") + + delete_space + + pynutil.delete("\"") + + pynini.closure(NEMO_NOT_QUOTE, 1) + + pynutil.delete("\"") + ) + optional_zone = pynini.closure(delete_space + insert_space + zone, 0, 1) + second = ( + pynutil.delete("seconds:") + + delete_space + + pynutil.delete("\"") + + pynini.closure(NEMO_NOT_QUOTE, 1) + + pynutil.delete("\"") + ) + graph_hms = ( + hour + + pynutil.insert(" hours ") + + delete_space + + minute + + pynutil.insert(" minutes and ") + + delete_space + + second + + pynutil.insert(" seconds") + + optional_suffix + + optional_zone + ) + graph_hms @= pynini.cdrewrite( + pynutil.delete("o ") + | pynini.cross("one minutes", "one minute") + | pynini.cross("one seconds", "one second") + | pynini.cross("one hours", "one hour"), + pynini.union(" ", "[BOS]"), + "", + NEMO_SIGMA, + ) + graph = hour + delete_space + insert_space + minute + optional_suffix + optional_zone + graph |= hour + insert_space + pynutil.insert("o'clock") + optional_zone + graph |= hour + 
delete_space + insert_space + suffix + optional_zone + graph |= graph_hms + delete_tokens = self.delete_tokens(graph) + self.fst = delete_tokens.optimize() diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/verbalize.py b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/verbalize.py new file mode 100644 index 0000000..cd3b140 --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/verbalize.py @@ -0,0 +1,82 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from nemo_text_processing.text_normalization.en.graph_utils import GraphFst +from nemo_text_processing.text_normalization.en.verbalizers.abbreviation import AbbreviationFst +from nemo_text_processing.text_normalization.en.verbalizers.cardinal import CardinalFst +from nemo_text_processing.text_normalization.en.verbalizers.date import DateFst +from nemo_text_processing.text_normalization.en.verbalizers.decimal import DecimalFst +from nemo_text_processing.text_normalization.en.verbalizers.electronic import ElectronicFst +from nemo_text_processing.text_normalization.en.verbalizers.fraction import FractionFst +from nemo_text_processing.text_normalization.en.verbalizers.measure import MeasureFst +from nemo_text_processing.text_normalization.en.verbalizers.money import MoneyFst +from nemo_text_processing.text_normalization.en.verbalizers.ordinal import OrdinalFst +from nemo_text_processing.text_normalization.en.verbalizers.roman import RomanFst +from nemo_text_processing.text_normalization.en.verbalizers.telephone import TelephoneFst +from nemo_text_processing.text_normalization.en.verbalizers.time import TimeFst +from nemo_text_processing.text_normalization.en.verbalizers.whitelist import WhiteListFst + + +class VerbalizeFst(GraphFst): + """ + Composes other verbalizer grammars. + For deployment, this grammar will be compiled and exported to OpenFst Finate State Archiv (FAR) File. + More details to deployment at NeMo/tools/text_processing_deployment. 
+ + Args: + deterministic: if True will provide a single transduction option, + for False multiple options (used for audio-based normalization) + """ + + def __init__(self, deterministic: bool = True): + super().__init__(name="verbalize", kind="verbalize", deterministic=deterministic) + cardinal = CardinalFst(deterministic=deterministic) + cardinal_graph = cardinal.fst + decimal = DecimalFst(cardinal=cardinal, deterministic=deterministic) + decimal_graph = decimal.fst + ordinal = OrdinalFst(deterministic=deterministic) + ordinal_graph = ordinal.fst + fraction = FractionFst(deterministic=deterministic) + fraction_graph = fraction.fst + telephone_graph = TelephoneFst(deterministic=deterministic).fst + electronic_graph = ElectronicFst(deterministic=deterministic).fst + measure = MeasureFst(decimal=decimal, cardinal=cardinal, fraction=fraction, deterministic=deterministic) + measure_graph = measure.fst + time_graph = TimeFst(deterministic=deterministic).fst + date_graph = DateFst(ordinal=ordinal, deterministic=deterministic).fst + money_graph = MoneyFst(decimal=decimal, deterministic=deterministic).fst + whitelist_graph = WhiteListFst(deterministic=deterministic).fst + + graph = ( + time_graph + | date_graph + | money_graph + | measure_graph + | ordinal_graph + | decimal_graph + | cardinal_graph + | telephone_graph + | electronic_graph + | fraction_graph + | whitelist_graph + ) + + roman_graph = RomanFst(deterministic=deterministic).fst + graph |= roman_graph + + if not deterministic: + abbreviation_graph = AbbreviationFst(deterministic=deterministic).fst + graph |= abbreviation_graph + + self.fst = graph diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/verbalize_final.py b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/verbalize_final.py new file mode 100644 index 0000000..6564aff --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/verbalize_final.py @@ -0,0 +1,75 @@ +# 
Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import ( + GraphFst, + delete_extra_space, + delete_space, + generator_main, +) +from nemo_text_processing.text_normalization.en.verbalizers.verbalize import VerbalizeFst +from nemo_text_processing.text_normalization.en.verbalizers.word import WordFst +from pynini.lib import pynutil + + + +class VerbalizeFinalFst(GraphFst): + """ + Finite state transducer that verbalizes an entire sentence, e.g. + tokens { name: "its" } tokens { time { hours: "twelve" minutes: "thirty" } } tokens { name: "now" } tokens { name: "." } -> its twelve thirty now . + + Args: + deterministic: if True will provide a single transduction option, + for False multiple options (used for audio-based normalization) + cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache. 
+ overwrite_cache: set to True to overwrite .far files + """ + + def __init__(self, deterministic: bool = True, cache_dir: str = None, overwrite_cache: bool = False): + super().__init__(name="verbalize_final", kind="verbalize", deterministic=deterministic) + + far_file = None + if cache_dir is not None and cache_dir != "None": + os.makedirs(cache_dir, exist_ok=True) + far_file = os.path.join(cache_dir, f"en_tn_{deterministic}_deterministic_verbalizer.far") + if not overwrite_cache and far_file and os.path.exists(far_file): + self.fst = pynini.Far(far_file, mode="r")["verbalize"] + + else: + verbalize = VerbalizeFst(deterministic=deterministic).fst + word = WordFst(deterministic=deterministic).fst + types = verbalize | word + + if deterministic: + graph = ( + pynutil.delete("tokens") + + delete_space + + pynutil.delete("{") + + delete_space + + types + + delete_space + + pynutil.delete("}") + ) + else: + graph = delete_space + types + delete_space + + graph = delete_space + pynini.closure(graph + delete_extra_space) + graph + delete_space + + self.fst = graph.optimize() + if far_file: + generator_main(far_file, {"verbalize": self.fst}) + diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/whitelist.py b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/whitelist.py new file mode 100644 index 0000000..96aa207 --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/whitelist.py @@ -0,0 +1,39 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import NEMO_CHAR, NEMO_SIGMA, GraphFst, delete_space +from pynini.lib import pynutil + + +class WhiteListFst(GraphFst): + """ + Finite state transducer for verbalizing whitelist + e.g. tokens { name: "misses" } } -> misses + + Args: + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + """ + + def __init__(self, deterministic: bool = True): + super().__init__(name="whitelist", kind="verbalize", deterministic=deterministic) + graph = ( + pynutil.delete("name:") + + delete_space + + pynutil.delete("\"") + + pynini.closure(NEMO_CHAR - " ", 1) + + pynutil.delete("\"") + ) + graph = graph @ pynini.cdrewrite(pynini.cross(u"\u00A0", " "), "", "", NEMO_SIGMA) + self.fst = graph.optimize() diff --git a/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/word.py b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/word.py new file mode 100644 index 0000000..e124f42 --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/en/verbalizers/word.py @@ -0,0 +1,35 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pynini +from nemo_text_processing.text_normalization.en.graph_utils import NEMO_CHAR, NEMO_SIGMA, GraphFst, delete_space +from pynini.lib import pynutil + + +class WordFst(GraphFst): + """ + Finite state transducer for verbalizing word + e.g. tokens { name: "sleep" } -> sleep + + Args: + deterministic: if True will provide a single transduction option, + for False multiple transduction are generated (used for audio-based normalization) + """ + + def __init__(self, deterministic: bool = True): + super().__init__(name="word", kind="verbalize", deterministic=deterministic) + chars = pynini.closure(NEMO_CHAR - " ", 1) + char = pynutil.delete("name:") + delete_space + pynutil.delete("\"") + chars + pynutil.delete("\"") + graph = char @ pynini.cdrewrite(pynini.cross(u"\u00A0", " "), "", "", NEMO_SIGMA) + + self.fst = graph.optimize() diff --git a/utils/speechio/nemo_text_processing/text_normalization/normalize.py b/utils/speechio/nemo_text_processing/text_normalization/normalize.py new file mode 100644 index 0000000..d22ef8c --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/normalize.py @@ -0,0 +1,479 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import os +import re +from argparse import ArgumentParser +from collections import OrderedDict +from math import factorial +from time import perf_counter +from typing import Dict, List, Union + +import pynini +import regex +from nemo_text_processing.text_normalization.data_loader_utils import ( + load_file, + post_process_punct, + pre_process, + write_file, +) +from nemo_text_processing.text_normalization.token_parser import PRESERVE_ORDER_KEY, TokenParser +from pynini.lib.rewrite import top_rewrite + +SPACE_DUP = re.compile(' {2,}') + + +class Normalizer: + """ + Normalizer class that converts text from written to spoken form. + Useful for TTS preprocessing. + + Args: + input_case: expected input capitalization + lang: language specifying the TN rules, by default: English + cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache. + overwrite_cache: set to True to overwrite .far files + whitelist: path to a file with whitelist replacements + post_process: WFST-based post processing, e.g. to remove extra spaces added during TN. + Note: punct_post_process flag in normalize() supports all languages. 
+ """ + + def __init__( + self, + input_case: str, + lang: str = 'en', + deterministic: bool = True, + cache_dir: str = None, + overwrite_cache: bool = False, + whitelist: str = None, + lm: bool = False, + post_process: bool = True, + ): + assert input_case in ["lower_cased", "cased"] + + self.post_processor = None + + if lang == "en": + from nemo_text_processing.text_normalization.en.verbalizers.post_processing import PostProcessingFst + from nemo_text_processing.text_normalization.en.verbalizers.verbalize_final import VerbalizeFinalFst + + if post_process: + self.post_processor = PostProcessingFst(cache_dir=cache_dir, overwrite_cache=overwrite_cache) + + if deterministic: + from nemo_text_processing.text_normalization.en.taggers.tokenize_and_classify import ClassifyFst + else: + if lm: + from nemo_text_processing.text_normalization.en.taggers.tokenize_and_classify_lm import ClassifyFst + else: + from nemo_text_processing.text_normalization.en.taggers.tokenize_and_classify_with_audio import ( + ClassifyFst, + ) + + elif lang == 'ru': + # Ru TN only support non-deterministic cases and produces multiple normalization options + # use normalize_with_audio.py + from nemo_text_processing.text_normalization.ru.taggers.tokenize_and_classify import ClassifyFst + from nemo_text_processing.text_normalization.ru.verbalizers.verbalize_final import VerbalizeFinalFst + elif lang == 'de': + from nemo_text_processing.text_normalization.de.taggers.tokenize_and_classify import ClassifyFst + from nemo_text_processing.text_normalization.de.verbalizers.verbalize_final import VerbalizeFinalFst + elif lang == 'es': + from nemo_text_processing.text_normalization.es.taggers.tokenize_and_classify import ClassifyFst + from nemo_text_processing.text_normalization.es.verbalizers.verbalize_final import VerbalizeFinalFst + self.tagger = ClassifyFst( + input_case=input_case, + deterministic=deterministic, + cache_dir=cache_dir, + overwrite_cache=overwrite_cache, + whitelist=whitelist, + ) + + 
self.verbalizer = VerbalizeFinalFst( + deterministic=deterministic, cache_dir=cache_dir, overwrite_cache=overwrite_cache + ) + + self.parser = TokenParser() + self.lang = lang + + self.processor = 0 + + def __process_batch(self, batch, verbose, punct_pre_process, punct_post_process): + """ + Normalizes batch of text sequences + Args: + batch: list of texts + verbose: whether to print intermediate meta information + punct_pre_process: whether to do punctuation pre processing + punct_post_process: whether to do punctuation post processing + """ + normalized_lines = [ + self.normalize( + text, verbose=verbose, punct_pre_process=punct_pre_process, punct_post_process=punct_post_process + ) + for text in tqdm(batch) + ] + return normalized_lines + + def _estimate_number_of_permutations_in_nested_dict( + self, token_group: Dict[str, Union[OrderedDict, str, bool]] + ) -> int: + num_perms = 1 + for k, inner in token_group.items(): + if isinstance(inner, dict): + num_perms *= self._estimate_number_of_permutations_in_nested_dict(inner) + num_perms *= factorial(len(token_group)) + return num_perms + + def _split_tokens_to_reduce_number_of_permutations( + self, tokens: List[dict], max_number_of_permutations_per_split: int = 729 + ) -> List[List[dict]]: + """ + Splits a sequence of tokens in a smaller sequences of tokens in a way that maximum number of composite + tokens permutations does not exceed ``max_number_of_permutations_per_split``. + + For example, + + .. 
code-block:: python + tokens = [ + {"tokens": {"date": {"year": "twenty eighteen", "month": "december", "day": "thirty one"}}}, + {"tokens": {"date": {"year": "twenty eighteen", "month": "january", "day": "eight"}}}, + ] + split = normalizer._split_tokens_to_reduce_number_of_permutations( + tokens, max_number_of_permutations_per_split=6 + ) + assert split == [ + [{"tokens": {"date": {"year": "twenty eighteen", "month": "december", "day": "thirty one"}}}], + [{"tokens": {"date": {"year": "twenty eighteen", "month": "january", "day": "eight"}}}], + ] + + Date tokens contain 3 items each which gives 6 permutations for every date. Since there are 2 dates, total + number of permutations would be ``6 * 6 == 36``. Parameter ``max_number_of_permutations_per_split`` equals 6, + so input sequence of tokens is split into 2 smaller sequences. + + Args: + tokens (:obj:`List[dict]`): a list of dictionaries, possibly nested. + max_number_of_permutations_per_split (:obj:`int`, `optional`, defaults to :obj:`243`): a maximum number + of permutations which can be generated from input sequence of tokens. + + Returns: + :obj:`List[List[dict]]`: a list of smaller sequences of tokens resulting from ``tokens`` split. + """ + splits = [] + prev_end_of_split = 0 + current_number_of_permutations = 1 + for i, token_group in enumerate(tokens): + n = self._estimate_number_of_permutations_in_nested_dict(token_group) + if n * current_number_of_permutations > max_number_of_permutations_per_split: + splits.append(tokens[prev_end_of_split:i]) + prev_end_of_split = i + current_number_of_permutations = 1 + if n > max_number_of_permutations_per_split: + raise ValueError( + f"Could not split token list with respect to condition that every split can generate number of " + f"permutations less or equal to " + f"`max_number_of_permutations_per_split={max_number_of_permutations_per_split}`. 
" + f"There is an unsplittable token group that generates more than " + f"{max_number_of_permutations_per_split} permutations. Try to increase " + f"`max_number_of_permutations_per_split` parameter." + ) + current_number_of_permutations *= n + splits.append(tokens[prev_end_of_split:]) + assert sum([len(s) for s in splits]) == len(tokens) + return splits + + def normalize( + self, text: str, verbose: bool = False, punct_pre_process: bool = False, punct_post_process: bool = False + ) -> str: + """ + Main function. Normalizes tokens from written to spoken form + e.g. 12 kg -> twelve kilograms + + Args: + text: string that may include semiotic classes + verbose: whether to print intermediate meta information + punct_pre_process: whether to perform punctuation pre-processing, for example, [25] -> [ 25 ] + punct_post_process: whether to normalize punctuation + + Returns: spoken form + """ + + original_text = text + if punct_pre_process: + text = pre_process(text) + text = text.strip() + if not text: + if verbose: + print(text) + return text + text = pynini.escape(text) + tagged_lattice = self.find_tags(text) + tagged_text = self.select_tag(tagged_lattice) + if verbose: + print(tagged_text) + self.parser(tagged_text) + tokens = self.parser.parse() + split_tokens = self._split_tokens_to_reduce_number_of_permutations(tokens) + output = "" + for s in split_tokens: + tags_reordered = self.generate_permutations(s) + verbalizer_lattice = None + for tagged_text in tags_reordered: + tagged_text = pynini.escape(tagged_text) + + verbalizer_lattice = self.find_verbalizer(tagged_text) + if verbalizer_lattice.num_states() != 0: + break + if verbalizer_lattice is None: + raise ValueError(f"No permutations were generated from tokens {s}") + output += ' ' + self.select_verbalizer(verbalizer_lattice) + output = SPACE_DUP.sub(' ', output[1:]) + + if self.lang == "en" and hasattr(self, 'post_processor'): + output = self.post_process(output) + + if punct_post_process: + # do post-processing 
based on Moses detokenizer + if self.processor: + output = self.processor.moses_detokenizer.detokenize([output], unescape=False) + output = post_process_punct(input=original_text, normalized_text=output) + else: + print("NEMO_NLP collection is not available: skipping punctuation post_processing") + + return output + + def split_text_into_sentences(self, text: str) -> List[str]: + """ + Split text into sentences. + + Args: + text: text + + Returns list of sentences + """ + lower_case_unicode = '' + upper_case_unicode = '' + if self.lang == "ru": + lower_case_unicode = '\u0430-\u04FF' + upper_case_unicode = '\u0410-\u042F' + + # Read and split transcript by utterance (roughly, sentences) + split_pattern = f"(? List[str]: + """ + Creates reorderings of dictionary elements and serializes as strings + + Args: + d: (nested) dictionary of key value pairs + + Return permutations of different string serializations of key value pairs + """ + l = [] + if PRESERVE_ORDER_KEY in d.keys(): + d_permutations = [d.items()] + else: + d_permutations = itertools.permutations(d.items()) + for perm in d_permutations: + subl = [""] + for k, v in perm: + if isinstance(v, str): + subl = ["".join(x) for x in itertools.product(subl, [f"{k}: \"{v}\" "])] + elif isinstance(v, OrderedDict): + rec = self._permute(v) + subl = ["".join(x) for x in itertools.product(subl, [f" {k} {{ "], rec, [f" }} "])] + elif isinstance(v, bool): + subl = ["".join(x) for x in itertools.product(subl, [f"{k}: true "])] + else: + raise ValueError() + l.extend(subl) + return l + + def generate_permutations(self, tokens: List[dict]): + """ + Generates permutations of string serializations of list of dictionaries + + Args: + tokens: list of dictionaries + + Returns string serialization of list of dictionaries + """ + + def _helper(prefix: str, tokens: List[dict], idx: int): + """ + Generates permutations of string serializations of given dictionary + + Args: + tokens: list of dictionaries + prefix: prefix string + idx: 
index of next dictionary + + Returns string serialization of dictionary + """ + if idx == len(tokens): + yield prefix + return + token_options = self._permute(tokens[idx]) + for token_option in token_options: + yield from _helper(prefix + token_option, tokens, idx + 1) + + return _helper("", tokens, 0) + + def find_tags(self, text: str) -> 'pynini.FstLike': + """ + Given text use tagger Fst to tag text + + Args: + text: sentence + + Returns: tagged lattice + """ + lattice = text @ self.tagger.fst + return lattice + + def select_tag(self, lattice: 'pynini.FstLike') -> str: + """ + Given tagged lattice return shortest path + + Args: + tagged_text: tagged text + + Returns: shortest path + """ + tagged_text = pynini.shortestpath(lattice, nshortest=1, unique=True).string() + return tagged_text + + def find_verbalizer(self, tagged_text: str) -> 'pynini.FstLike': + """ + Given tagged text creates verbalization lattice + This is context-independent. + + Args: + tagged_text: input text + + Returns: verbalized lattice + """ + lattice = tagged_text @ self.verbalizer.fst + return lattice + + def select_verbalizer(self, lattice: 'pynini.FstLike') -> str: + """ + Given verbalized lattice return shortest path + + Args: + lattice: verbalization lattice + + Returns: shortest path + """ + output = pynini.shortestpath(lattice, nshortest=1, unique=True).string() + # lattice = output @ self.verbalizer.punct_graph + # output = pynini.shortestpath(lattice, nshortest=1, unique=True).string() + return output + + def post_process(self, normalized_text: 'pynini.FstLike') -> str: + """ + Runs post processing graph on normalized text + + Args: + normalized_text: normalized text + + Returns: shortest path + """ + normalized_text = normalized_text.strip() + if not normalized_text: + return normalized_text + normalized_text = pynini.escape(normalized_text) + + if self.post_processor is not None: + normalized_text = top_rewrite(normalized_text, self.post_processor.fst) + return normalized_text + + 
def parse_args():
    """Build the CLI argument parser for the text normalizer and parse argv.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = ArgumentParser()

    # Exactly one input source may be given: a literal string or a file.
    # (renamed from `input`, which shadowed the builtin)
    source = parser.add_mutually_exclusive_group()
    source.add_argument("--text", dest="input_string", help="input string", type=str)
    source.add_argument("--input_file", dest="input_file", help="input file path", type=str)

    parser.add_argument('--output_file', dest="output_file", help="output file path", type=str)
    parser.add_argument("--language", help="language", choices=["en", "de", "es"], default="en", type=str)
    parser.add_argument(
        "--input_case", help="input capitalization", choices=["lower_cased", "cased"], default="cased", type=str
    )
    parser.add_argument("--verbose", help="print info for debugging", action='store_true')
    parser.add_argument(
        "--punct_post_process",
        help="set to True to enable punctuation post processing to match input.",
        action="store_true",
    )
    parser.add_argument(
        "--punct_pre_process", help="set to True to enable punctuation pre processing", action="store_true"
    )
    parser.add_argument("--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true")
    parser.add_argument("--whitelist", help="path to a file with with whitelist", default=None, type=str)
    parser.add_argument(
        "--cache_dir",
        help="path to a dir with .far grammar file. Set to None to avoid using cache",
        default=None,
        type=str,
    )
    return parser.parse_args()


if __name__ == "__main__":
    start_time = perf_counter()

    args = parse_args()
    whitelist = os.path.abspath(args.whitelist) if args.whitelist else None

    if not args.input_string and not args.input_file:
        raise ValueError("Either `--text` or `--input_file` required")

    normalizer = Normalizer(
        input_case=args.input_case,
        cache_dir=args.cache_dir,
        overwrite_cache=args.overwrite_cache,
        whitelist=whitelist,
        lang=args.language,
    )
    if args.input_string:
        # Single-string mode: normalize and print the result directly.
        print(
            normalizer.normalize(
                args.input_string,
                verbose=args.verbose,
                punct_pre_process=args.punct_pre_process,
                punct_post_process=args.punct_post_process,
            )
        )
    elif args.input_file:
        # Batch mode: one sentence per element as produced by load_file.
        print("Loading data: " + args.input_file)
        data = load_file(args.input_file)

        print("- Data: " + str(len(data)) + " sentences")
        normalizer_prediction = normalizer.normalize_list(
            data,
            verbose=args.verbose,
            punct_pre_process=args.punct_pre_process,
            punct_post_process=args.punct_post_process,
        )
        if args.output_file:
            write_file(args.output_file, normalizer_prediction)
            print(f"- Normalized. Writing out to {args.output_file}")
        else:
            print(normalizer_prediction)

    print(f"Execution time: {perf_counter() - start_time:.02f} sec")

# --- diff residue (file boundary of the pasted dump), kept as comments ---
# diff --git a/utils/speechio/nemo_text_processing/text_normalization/normalize_with_audio.py
#          b/utils/speechio/nemo_text_processing/text_normalization/normalize_with_audio.py
# new file mode 100644 index 0000000..89927b2 --- /dev/null
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import time
from argparse import ArgumentParser
from glob import glob
from typing import List, Tuple

import pynini
from joblib import Parallel, delayed
from nemo_text_processing.text_normalization.data_loader_utils import post_process_punct, pre_process
from nemo_text_processing.text_normalization.normalize import Normalizer
from pynini.lib import rewrite
from tqdm import tqdm

try:
    from nemo.collections.asr.metrics.wer import word_error_rate
    from nemo.collections.asr.models import ASRModel

    # ASR rescoring is optional: only needed when an audio file / manifest is given.
    ASR_AVAILABLE = True
except (ModuleNotFoundError, ImportError):
    ASR_AVAILABLE = False


"""
The script provides multiple normalization options and chooses the best one that minimizes CER of the ASR output
(most of the semiotic classes use deterministic=False flag).

To run this script with a .json manifest file, the manifest file should contain the following fields:
    "audio_data" - path to the audio file
    "text" - raw text
    "pred_text" - ASR model prediction

    See https://github.com/NVIDIA/NeMo/blob/main/examples/asr/transcribe_speech.py on how to add ASR predictions

    When the manifest is ready, run:
        python normalize_with_audio.py \
            --audio_data PATH/TO/MANIFEST.JSON \
            --language en


To run with a single audio file, specify path to audio and text with:
    python normalize_with_audio.py \
        --audio_data PATH/TO/AUDIO.WAV \
        --language en \
        --text raw text OR PATH/TO/.TXT/FILE
        --model QuartzNet15x5Base-En \
        --verbose

To see possible normalization options for a text input without an audio file (could be used for debugging), run:
    python normalize_with_audio.py --text "RAW TEXT"

Specify `--cache_dir` to generate .far grammars once and re-used them for faster inference
"""


class NormalizerWithAudio(Normalizer):
    """
    Normalizer class that converts text from written to spoken form.
    Useful for TTS preprocessing.

    Args:
        input_case: expected input capitalization
        lang: language
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
        whitelist: path to a file with whitelist replacements
        lm: set to True for WFST+LM scoring mode (English only)
        post_process: WFST-based post processing, e.g. to remove extra spaces added during TN.
            Note: punct_post_process flag in normalize() supports all languages.
    """

    def __init__(
        self,
        input_case: str,
        lang: str = 'en',
        cache_dir: str = None,
        overwrite_cache: bool = False,
        whitelist: str = None,
        lm: bool = False,
        post_process: bool = True,
    ):

        # deterministic=False: keep multiple candidate normalizations so that
        # select_best_match() can later pick the one minimizing CER.
        super().__init__(
            input_case=input_case,
            lang=lang,
            deterministic=False,
            cache_dir=cache_dir,
            overwrite_cache=overwrite_cache,
            whitelist=whitelist,
            lm=lm,
            post_process=post_process,
        )
        self.lm = lm

    def normalize(self, text: str, n_tagged: int, punct_post_process: bool = True, verbose: bool = False,) -> str:
        """
        Main function. Normalizes tokens from written to spoken form
            e.g. 12 kg -> twelve kilograms

        Args:
            text: string that may include semiotic classes
            n_tagged: number of tagged options to consider, -1 - to get all possible tagged options
            punct_post_process: whether to normalize punctuation
            verbose: whether to print intermediate meta information

        Returns:
            normalized text options (usually there are multiple ways of normalizing a given semiotic class);
            in lm mode returns (list_of_options, weights) instead.
        """

        if len(text.split()) > 500:
            raise ValueError(
                "Your input is too long. Please split up the input into sentences, "
                "or strings with fewer than 500 words"
            )

        original_text = text
        text = pre_process(text)  # to handle []

        text = text.strip()
        if not text:
            if verbose:
                print(text)
            return text
        text = pynini.escape(text)
        # NOTE(review): unconditional print of the escaped input — looks like
        # leftover debug output (it is not guarded by `verbose`); confirm intent.
        print(text)

        if self.lm:
            if self.lang not in ["en"]:
                raise ValueError(f"{self.lang} is not supported in LM mode")

            if self.lang == "en":
                # this to keep arpabet phonemes in the list of options
                if "[" in text and "]" in text:

                    lattice = rewrite.rewrite_lattice(text, self.tagger.fst)
                else:
                    try:
                        lattice = rewrite.rewrite_lattice(text, self.tagger.fst_no_digits)
                    except pynini.lib.rewrite.Error:
                        lattice = rewrite.rewrite_lattice(text, self.tagger.fst)
                lattice = rewrite.lattice_to_nshortest(lattice, n_tagged)
                # paths().items() yields (input, output, weight); keep (output, weight)
                tagged_texts = [(x[1], float(x[2])) for x in lattice.paths().items()]
                tagged_texts.sort(key=lambda x: x[1])
                tagged_texts, weights = list(zip(*tagged_texts))
        else:
            tagged_texts = self._get_tagged_text(text, n_tagged)
        # non-deterministic Eng normalization uses tagger composed with verbalizer, no permutation in between
        if self.lang == "en":
            normalized_texts = tagged_texts
            normalized_texts = [self.post_process(text) for text in normalized_texts]
        else:
            normalized_texts = []
            for tagged_text in tagged_texts:
                self._verbalize(tagged_text, normalized_texts, verbose=verbose)

        if len(normalized_texts) == 0:
            # NOTE(review): empty ValueError gives callers no context about why
            # no verbalization was produced for the input.
            raise ValueError()

        if punct_post_process:
            # do post-processing based on Moses detokenizer
            if self.processor:
                normalized_texts = [self.processor.detokenize([t]) for t in normalized_texts]
                normalized_texts = [
                    post_process_punct(input=original_text, normalized_text=t) for t in normalized_texts
                ]

        if self.lm:
            # Deduplicate while keeping each option's weight, then re-sort by weight.
            remove_dup = sorted(list(set(zip(normalized_texts, weights))), key=lambda x: x[1])
            normalized_texts, weights = zip(*remove_dup)
            return list(normalized_texts), weights

        normalized_texts = set(normalized_texts)
        return normalized_texts

    def _get_tagged_text(self, text, n_tagged):
        """
        Returns text after tokenize and classify
        Args;
            text: input text
            n_tagged: number of tagged options to consider, -1 - return all possible tagged options
        """
        if n_tagged == -1:
            if self.lang == "en":
                # this to keep arpabet phonemes in the list of options
                if "[" in text and "]" in text:
                    tagged_texts = rewrite.rewrites(text, self.tagger.fst)
                else:
                    try:
                        tagged_texts = rewrite.rewrites(text, self.tagger.fst_no_digits)
                    except pynini.lib.rewrite.Error:
                        tagged_texts = rewrite.rewrites(text, self.tagger.fst)
            else:
                tagged_texts = rewrite.rewrites(text, self.tagger.fst)
        else:
            if self.lang == "en":
                # this to keep arpabet phonemes in the list of options
                if "[" in text and "]" in text:
                    tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)
                else:
                    try:
                        # try self.tagger graph that produces output without digits
                        tagged_texts = rewrite.top_rewrites(text, self.tagger.fst_no_digits, nshortest=n_tagged)
                    except pynini.lib.rewrite.Error:
                        tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)
            else:
                tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)
        return tagged_texts

    def _verbalize(self, tagged_text: str, normalized_texts: List[str], verbose: bool = False):
        """
        Verbalizes tagged text

        Args:
            tagged_text: text with tags
            normalized_texts: list of possible normalization options
            verbose: if true prints intermediate classification results
        """

        def get_verbalized_text(tagged_text):
            return rewrite.rewrites(tagged_text, self.verbalizer.fst)

        self.parser(tagged_text)
        tokens = self.parser.parse()
        tags_reordered = self.generate_permutations(tokens)
        for tagged_text_reordered in tags_reordered:
            try:
                tagged_text_reordered = pynini.escape(tagged_text_reordered)
                normalized_texts.extend(get_verbalized_text(tagged_text_reordered))
                if verbose:
                    print(tagged_text_reordered)

            except pynini.lib.rewrite.Error:
                # Some permutations cannot be verbalized; skip them silently.
                continue

    def select_best_match(
        self,
        normalized_texts: List[str],
        input_text: str,
        pred_text: str,
        verbose: bool = False,
        remove_punct: bool = False,
        cer_threshold: int = 100,
    ):
        """
        Selects the best normalization option based on the lowest CER

        Args:
            normalized_texts: normalized text options
            input_text: input text
            pred_text: ASR model transcript of the audio file corresponding to the normalized text
            verbose: whether to print intermediate meta information
            remove_punct: whether to remove punctuation before calculating CER
            cer_threshold: if CER for pred_text is above the cer_threshold, no normalization will be performed

        Returns:
            normalized text with the lowest CER and CER value
        """
        # Empty ASR transcript: nothing to score against, return the input unchanged.
        if pred_text == "":
            return input_text, cer_threshold

        normalized_texts_cer = calculate_cer(normalized_texts, pred_text, remove_punct)
        normalized_texts_cer = sorted(normalized_texts_cer, key=lambda x: x[1])
        normalized_text, cer = normalized_texts_cer[0]

        if cer > cer_threshold:
            return input_text, cer

        if verbose:
            print('-' * 30)
            for option in normalized_texts:
                print(option)
            print('-' * 30)
        return normalized_text, cer


def calculate_cer(normalized_texts: List[str], pred_text: str, remove_punct=False) -> List[Tuple[str, float]]:
    """
    Calculates character error rate (CER)

    Args:
        normalized_texts: normalized text options
        pred_text: ASR model output

    Returns: normalized options with corresponding CER
    """
    normalized_options = []
    for text in normalized_texts:
        # CER is computed on a lowercased, hyphen-free copy; the original text
        # (not text_clean) is what gets returned with its score.
        text_clean = text.replace('-', ' ').lower()
        if remove_punct:
            for punct in "!?:;,.-()*+-/<=>@^_":
                text_clean = text_clean.replace(punct, "")
        cer = round(word_error_rate([pred_text], [text_clean], use_cer=True) * 100, 2)
        normalized_options.append((text, cer))
    return normalized_options


def get_asr_model(asr_model):
    """
    Returns ASR Model

    Args:
asr_model: NeMo ASR model + """ + if os.path.exists(args.model): + asr_model = ASRModel.restore_from(asr_model) + elif args.model in ASRModel.get_available_model_names(): + asr_model = ASRModel.from_pretrained(asr_model) + else: + raise ValueError( + f'Provide path to the pretrained checkpoint or choose from {ASRModel.get_available_model_names()}' + ) + return asr_model + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument("--text", help="input string or path to a .txt file", default=None, type=str) + parser.add_argument( + "--input_case", help="input capitalization", choices=["lower_cased", "cased"], default="cased", type=str + ) + parser.add_argument( + "--language", help="Select target language", choices=["en", "ru", "de", "es"], default="en", type=str + ) + parser.add_argument("--audio_data", default=None, help="path to an audio file or .json manifest") + parser.add_argument( + '--model', type=str, default='QuartzNet15x5Base-En', help='Pre-trained model name or path to model checkpoint' + ) + parser.add_argument( + "--n_tagged", + type=int, + default=30, + help="number of tagged options to consider, -1 - return all possible tagged options", + ) + parser.add_argument("--verbose", help="print info for debugging", action="store_true") + parser.add_argument( + "--no_remove_punct_for_cer", + help="Set to True to NOT remove punctuation before calculating CER", + action="store_true", + ) + parser.add_argument( + "--no_punct_post_process", help="set to True to disable punctuation post processing", action="store_true" + ) + parser.add_argument("--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true") + parser.add_argument("--whitelist", help="path to a file with with whitelist", default=None, type=str) + parser.add_argument( + "--cache_dir", + help="path to a dir with .far grammar file. 
Set to None to avoid using cache", + default=None, + type=str, + ) + parser.add_argument("--n_jobs", default=-2, type=int, help="The maximum number of concurrently running jobs") + parser.add_argument( + "--lm", action="store_true", help="Set to True for WFST+LM. Only available for English right now." + ) + parser.add_argument( + "--cer_threshold", + default=100, + type=int, + help="if CER for pred_text is above the cer_threshold, no normalization will be performed", + ) + parser.add_argument("--batch_size", default=200, type=int, help="Number of examples for each process") + return parser.parse_args() + + +def _normalize_line( + normalizer: NormalizerWithAudio, n_tagged, verbose, line: str, remove_punct, punct_post_process, cer_threshold +): + line = json.loads(line) + pred_text = line["pred_text"] + + normalized_texts = normalizer.normalize( + text=line["text"], verbose=verbose, n_tagged=n_tagged, punct_post_process=punct_post_process, + ) + + normalized_texts = set(normalized_texts) + normalized_text, cer = normalizer.select_best_match( + normalized_texts=normalized_texts, + input_text=line["text"], + pred_text=pred_text, + verbose=verbose, + remove_punct=remove_punct, + cer_threshold=cer_threshold, + ) + line["nemo_normalized"] = normalized_text + line["CER_nemo_normalized"] = cer + return line + + +def normalize_manifest( + normalizer, + audio_data: str, + n_jobs: int, + n_tagged: int, + remove_punct: bool, + punct_post_process: bool, + batch_size: int, + cer_threshold: int, +): + """ + Args: + args.audio_data: path to .json manifest file. 
+ """ + + def __process_batch(batch_idx: int, batch: List[str], dir_name: str): + """ + Normalizes batch of text sequences + Args: + batch: list of texts + batch_idx: batch index + dir_name: path to output directory to save results + """ + normalized_lines = [ + _normalize_line( + normalizer, + n_tagged, + verbose=False, + line=line, + remove_punct=remove_punct, + punct_post_process=punct_post_process, + cer_threshold=cer_threshold, + ) + for line in tqdm(batch) + ] + + with open(f"{dir_name}/{batch_idx:05}.json", "w") as f_out: + for line in normalized_lines: + f_out.write(json.dumps(line, ensure_ascii=False) + '\n') + + print(f"Batch -- {batch_idx} -- is complete") + + manifest_out = audio_data.replace('.json', '_normalized.json') + with open(audio_data, 'r') as f: + lines = f.readlines() + + print(f'Normalizing {len(lines)} lines of {audio_data}...') + + # to save intermediate results to a file + batch = min(len(lines), batch_size) + + tmp_dir = manifest_out.replace(".json", "_parts") + os.makedirs(tmp_dir, exist_ok=True) + + Parallel(n_jobs=n_jobs)( + delayed(__process_batch)(idx, lines[i : i + batch], tmp_dir) + for idx, i in enumerate(range(0, len(lines), batch)) + ) + + # aggregate all intermediate files + with open(manifest_out, "w") as f_out: + for batch_f in sorted(glob(f"{tmp_dir}/*.json")): + with open(batch_f, "r") as f_in: + lines = f_in.read() + f_out.write(lines) + + print(f'Normalized version saved at {manifest_out}') + + +if __name__ == "__main__": + args = parse_args() + + if not ASR_AVAILABLE and args.audio_data: + raise ValueError("NeMo ASR collection is not installed.") + start = time.time() + args.whitelist = os.path.abspath(args.whitelist) if args.whitelist else None + if args.text is not None: + normalizer = NormalizerWithAudio( + input_case=args.input_case, + lang=args.language, + cache_dir=args.cache_dir, + overwrite_cache=args.overwrite_cache, + whitelist=args.whitelist, + lm=args.lm, + ) + + if os.path.exists(args.text): + with 
open(args.text, 'r') as f: + args.text = f.read().strip() + normalized_texts = normalizer.normalize( + text=args.text, + verbose=args.verbose, + n_tagged=args.n_tagged, + punct_post_process=not args.no_punct_post_process, + ) + + if not normalizer.lm: + normalized_texts = set(normalized_texts) + if args.audio_data: + asr_model = get_asr_model(args.model) + pred_text = asr_model.transcribe([args.audio_data])[0] + normalized_text, cer = normalizer.select_best_match( + normalized_texts=normalized_texts, + pred_text=pred_text, + input_text=args.text, + verbose=args.verbose, + remove_punct=not args.no_remove_punct_for_cer, + cer_threshold=args.cer_threshold, + ) + print(f"Transcript: {pred_text}") + print(f"Normalized: {normalized_text}") + else: + print("Normalization options:") + for norm_text in normalized_texts: + print(norm_text) + elif not os.path.exists(args.audio_data): + raise ValueError(f"{args.audio_data} not found.") + elif args.audio_data.endswith('.json'): + normalizer = NormalizerWithAudio( + input_case=args.input_case, + lang=args.language, + cache_dir=args.cache_dir, + overwrite_cache=args.overwrite_cache, + whitelist=args.whitelist, + ) + normalize_manifest( + normalizer=normalizer, + audio_data=args.audio_data, + n_jobs=args.n_jobs, + n_tagged=args.n_tagged, + remove_punct=not args.no_remove_punct_for_cer, + punct_post_process=not args.no_punct_post_process, + batch_size=args.batch_size, + cer_threshold=args.cer_threshold, + ) + else: + raise ValueError( + "Provide either path to .json manifest in '--audio_data' OR " + + "'--audio_data' path to audio file and '--text' path to a text file OR" + "'--text' string text (for debugging without audio)" + ) + print(f'Execution time: {round((time.time() - start)/60, 2)} min.') diff --git a/utils/speechio/nemo_text_processing/text_normalization/run_evaluate.py b/utils/speechio/nemo_text_processing/text_normalization/run_evaluate.py new file mode 100644 index 0000000..5f23dbd --- /dev/null +++ 
b/utils/speechio/nemo_text_processing/text_normalization/run_evaluate.py @@ -0,0 +1,117 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from argparse import ArgumentParser + +from nemo_text_processing.text_normalization.data_loader_utils import ( + evaluate, + known_types, + load_files, + training_data_to_sentences, + training_data_to_tokens, +) +from nemo_text_processing.text_normalization.normalize import Normalizer + + +''' +Runs Evaluation on data in the format of : \t\t<`self` if trivial class or normalized text> +like the Google text normalization data https://www.kaggle.com/richardwilliamsproat/text-normalization-for-english-russian-and-polish +''' + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument("--input", help="input file path", type=str) + parser.add_argument("--lang", help="language", choices=['en'], default="en", type=str) + parser.add_argument( + "--input_case", help="input capitalization", choices=["lower_cased", "cased"], default="cased", type=str + ) + parser.add_argument( + "--cat", + dest="category", + help="focus on class only (" + ", ".join(known_types) + ")", + type=str, + default=None, + choices=known_types, + ) + parser.add_argument("--filter", action='store_true', help="clean data for normalization purposes") + return parser.parse_args() + + +if __name__ == "__main__": + # Example usage: + # python run_evaluate.py --input= --cat= --filter + args = 
parse_args() + if args.lang == 'en': + from nemo_text_processing.text_normalization.en.clean_eval_data import filter_loaded_data + file_path = args.input + normalizer = Normalizer(input_case=args.input_case, lang=args.lang) + + print("Loading training data: " + file_path) + training_data = load_files([file_path]) + + if args.filter: + training_data = filter_loaded_data(training_data) + + if args.category is None: + print("Sentence level evaluation...") + sentences_un_normalized, sentences_normalized, _ = training_data_to_sentences(training_data) + print("- Data: " + str(len(sentences_normalized)) + " sentences") + sentences_prediction = normalizer.normalize_list(sentences_un_normalized) + print("- Normalized. Evaluating...") + sentences_accuracy = evaluate( + preds=sentences_prediction, labels=sentences_normalized, input=sentences_un_normalized + ) + print("- Accuracy: " + str(sentences_accuracy)) + + print("Token level evaluation...") + tokens_per_type = training_data_to_tokens(training_data, category=args.category) + token_accuracy = {} + for token_type in tokens_per_type: + print("- Token type: " + token_type) + tokens_un_normalized, tokens_normalized = tokens_per_type[token_type] + print(" - Data: " + str(len(tokens_normalized)) + " tokens") + tokens_prediction = normalizer.normalize_list(tokens_un_normalized) + print(" - Denormalized. 
Evaluating...") + token_accuracy[token_type] = evaluate( + preds=tokens_prediction, labels=tokens_normalized, input=tokens_un_normalized + ) + print(" - Accuracy: " + str(token_accuracy[token_type])) + token_count_per_type = {token_type: len(tokens_per_type[token_type][0]) for token_type in tokens_per_type} + token_weighted_accuracy = [ + token_count_per_type[token_type] * accuracy for token_type, accuracy in token_accuracy.items() + ] + print("- Accuracy: " + str(sum(token_weighted_accuracy) / sum(token_count_per_type.values()))) + print(" - Total: " + str(sum(token_count_per_type.values())), '\n') + + print(" - Total: " + str(sum(token_count_per_type.values())), '\n') + + for token_type in token_accuracy: + if token_type not in known_types: + raise ValueError("Unexpected token type: " + token_type) + + if args.category is None: + c1 = ['Class', 'sent level'] + known_types + c2 = ['Num Tokens', len(sentences_normalized)] + [ + token_count_per_type[known_type] if known_type in tokens_per_type else '0' for known_type in known_types + ] + c3 = ['Normalization', sentences_accuracy] + [ + token_accuracy[known_type] if known_type in token_accuracy else '0' for known_type in known_types + ] + + for i in range(len(c1)): + print(f'{str(c1[i]):10s} | {str(c2[i]):10s} | {str(c3[i]):5s}') + else: + print(f'numbers\t{token_count_per_type[args.category]}') + print(f'Normalization\t{token_accuracy[args.category]}') diff --git a/utils/speechio/nemo_text_processing/text_normalization/token_parser.py b/utils/speechio/nemo_text_processing/text_normalization/token_parser.py new file mode 100644 index 0000000..d3f7fd9 --- /dev/null +++ b/utils/speechio/nemo_text_processing/text_normalization/token_parser.py @@ -0,0 +1,192 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import string
from collections import OrderedDict
from typing import Dict, List, Union

PRESERVE_ORDER_KEY = "preserve_order"
# NOTE(review): upstream NeMo defines this sentinel as "<EOS>"; the
# angle-bracket token appears to have been stripped by the paste/mangling of
# this dump — verify against the installed package before relying on "".
EOS = ""


class TokenParser:
    """
    Parses tokenized/classified text, e.g. 'tokens { money { integer: "20" currency: "$" } } tokens { name: "left"}'

    Args
        text: tokenized text
    """

    def __call__(self, text):
        """
        Setup function

        Args:
            text: text to be parsed

        """
        # Resets the cursor state so one parser instance can be reused per input.
        self.text = text
        self.len_text = len(text)
        self.char = text[0]  # cannot handle empty string
        self.index = 0

    def parse(self) -> List[dict]:
        """
        Main function. Implements grammar:
        A -> space F space F space F ... space

        Returns list of dictionaries
        """
        l = list()
        while self.parse_ws():
            token = self.parse_token()
            if not token:
                break
            l.append(token)
        return l

    def parse_token(self) -> Dict[str, Union[str, dict]]:
        """
        Implements grammar:
        F-> no_space KG no_space

        Returns: K, G as dictionary values
        """
        d = OrderedDict()
        key = self.parse_string_key()
        if key is None:
            return None
        self.parse_ws()
        if key == PRESERVE_ORDER_KEY:
            # preserve_order is a bare flag: `preserve_order: true`
            self.parse_char(":")
            self.parse_ws()
            value = self.parse_chars("true")
        else:
            value = self.parse_token_value()

        d[key] = value
        return d

    def parse_token_value(self) -> Union[str, dict]:
        """
        Implements grammar:
        G-> no_space :"VALUE" no_space | no_space {A} no_space

        Returns: string or dictionary
        """
        if self.char == ":":
            # Scalar value: `: "VALUE"`
            self.parse_char(":")
            self.parse_ws()
            self.parse_char("\"")
            value_string = self.parse_string_value()
            self.parse_char("\"")
            return value_string
        elif self.char == "{":
            # Nested token group: `{ A }` — recurse and flatten into one dict.
            d = OrderedDict()
            self.parse_char("{")
            list_token_dicts = self.parse()
            # flatten tokens
            for tok_dict in list_token_dicts:
                for k, v in tok_dict.items():
                    d[k] = v
            self.parse_char("}")
            return d
        else:
            raise ValueError()

    def parse_char(self, exp) -> bool:
        """
        Parses character

        Args:
            exp: character to read in

        Returns true if successful
        """
        assert self.char == exp
        self.read()
        return True

    def parse_chars(self, exp) -> bool:
        """
        Parses characters

        Args:
            exp: characters to read in

        Returns true if successful
        """
        ok = False
        for x in exp:
            ok |= self.parse_char(x)
        return ok

    def parse_string_key(self) -> str:
        """
        Parses string key, can only contain ascii and '_' characters

        Returns parsed string key
        """
        assert self.char not in string.whitespace and self.char != EOS
        incl_criterium = string.ascii_letters + "_"
        l = []
        while self.char in incl_criterium:
            l.append(self.char)
            if not self.read():
                raise ValueError()

        if not l:
            return None
        return "".join(l)

    def parse_string_value(self) -> str:
        """
        Parses string value, ends with quote followed by space

        Returns parsed string value
        """
        assert self.char not in string.whitespace and self.char != EOS
        l = []
        # Stop only at a quote that is followed by a space, so embedded
        # quotes inside the value are tolerated.
        while self.char != "\"" or self.text[self.index + 1] != " ":
            l.append(self.char)
            if not self.read():
                raise ValueError()

        if not l:
            return None
        return "".join(l)

    def parse_ws(self):
        """
        Deletes whitespaces.

        Returns true if not EOS after parsing
        """
        not_eos = self.char != EOS
        while not_eos and self.char == " ":
            not_eos = self.read()
        return not_eos

    def read(self):
        """
        Reads in next char.

        Returns true if not EOS
        """
        if self.index < self.len_text - 1:  # should be unique
            self.index += 1
            self.char = self.text[self.index]
            return True
        self.char = EOS
        return False

# --- diff residue (file boundary of the pasted dump) ---
# diff --git a/utils/speechio/textnorm_en.py b/utils/speechio/textnorm_en.py
# new file mode 100644 index 0000000..aaf1fc7 --- /dev/null +++ b/utils/speechio/textnorm_en.py

#!/usr/bin/env python3
# coding=utf-8
# Copyright 2022 Ruiqi WANG, Jinpeng LI, Jiayu DU
#
# only tested and validated on pynini v2.1.5 via : 'conda install -c conda-forge pynini'
# pynini v2.1.0 doesn't work
#

import argparse
import os
import string
import sys

from nemo_text_processing.text_normalization.normalize import Normalizer


def read_interjections(filepath):
    """Read a comma-separated interjection list and return it with the
    original, upper-cased and lower-cased variants of every word (deduplicated)."""
    interjections = []
    with open(filepath) as f:
        for line in f:
            words = [x.strip() for x in line.split(',')]
            interjections += [w for w in words] + [w.upper() for w in words] + [w.lower() for w in words]
    return list(set(interjections))  # deduplicated


if __name__ == '__main__':
    p = argparse.ArgumentParser()
    p.add_argument('ifile', help='input filename, assume utf-8 encoding')
    p.add_argument('ofile', help='output filename')
    p.add_argument('--to_upper', action='store_true', help='convert to upper case')
p.add_argument('--to_lower', action='store_true', help='convert to lower case') + p.add_argument('--has_key', action='store_true', help="input text has Kaldi's key as first field.") + p.add_argument('--log_interval', type=int, default=10000, help='log interval in number of processed lines') + args = p.parse_args() + + nemo_tn_en = Normalizer(input_case='lower_cased', lang='en') + + itj = read_interjections(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'interjections_en.csv')) + itj_map = {x: True for x in itj} + + certain_single_quote_items = ["\"'", "'?", "'!", "'.", "?'", "!'", ".'", "''", "'", "'"] + single_quote_removed_items = [x.replace("'", '') for x in certain_single_quote_items] + + puncts_to_remove = string.punctuation.replace("'", '') + "—–“”" + puncts_trans = str.maketrans(puncts_to_remove, ' ' * len(puncts_to_remove), '') + + n = 0 + with open(args.ifile, 'r', encoding='utf8') as fi, open(args.ofile, 'w+', encoding='utf8') as fo: + for line in fi: + if args.has_key: + cols = line.strip().split(maxsplit=1) + key, text = cols[0].strip(), cols[1].strip() if len(cols) == 2 else '' + else: + text = line.strip() + + text = text.replace("‘", "'").replace("’", "'") + + # nemo text normalization + # modifications to NeMo: + # 1. added UK to US conversion: nemo_text_processing/text_normalization/en/data/whitelist/UK_to_US.tsv + # 2. swith 'oh' to 'o' in year TN to avoid confusion with interjections, e.g.: + # 1805: eighteen oh five -> eighteen o five + text = nemo_tn_en.normalize(text.lower()) + + # Punctuations + # NOTE(2022.10 Jiayu): + # Single quote removal is not perfect. + # ' needs to be reserved for: + # Abbreviations: + # I'm, don't, she'd, 'cause, Sweet Child o' Mine, Guns N' Roses, ... + # Possessions: + # John's, the king's, parents', ... 
+ text = '' + text + '' + for x, y in zip(certain_single_quote_items, single_quote_removed_items): + text = text.replace(x, y) + text = text.replace('', '').replace('', '') + + text = text.translate(puncts_trans).replace(" ' ", " ") + + # Interjections + text = ' '.join([x for x in text.strip().split() if x not in itj_map]) + + # Cases + if args.to_upper and args.to_lower: + sys.stderr.write('text norm: to_upper OR to_lower?') + exit(1) + if args.to_upper: + text = text.upper() + if args.to_lower: + text = text.lower() + + if args.has_key: + print(key + '\t' + text, file=fo) + else: + print(text, file=fo) + + n += 1 + if n % args.log_interval == 0: + print(f'text norm: {n} lines done.', file=sys.stderr) + print(f'text norm: {n} lines done in total.', file=sys.stderr) diff --git a/utils/speechio/textnorm_zh.py b/utils/speechio/textnorm_zh.py new file mode 100644 index 0000000..9a671e6 --- /dev/null +++ b/utils/speechio/textnorm_zh.py @@ -0,0 +1,1204 @@ +#!/usr/bin/env python3 +# coding=utf-8 + +# Authors: +# 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git) +# 2019.9 - 2022 Jiayu DU +# +# requirements: +# - python 3.X +# notes: python 2.X WILL fail or produce misleading results + +import sys, os, argparse +import string, re +import csv + +# ================================================================================ # +# basic constant +# ================================================================================ # +CHINESE_DIGIS = u'零一二三四五六七八九' +BIG_CHINESE_DIGIS_SIMPLIFIED = u'零壹贰叁肆伍陆柒捌玖' +BIG_CHINESE_DIGIS_TRADITIONAL = u'零壹貳參肆伍陸柒捌玖' +SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = u'十百千万' +SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = u'拾佰仟萬' +LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'亿兆京垓秭穰沟涧正载' +LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'億兆京垓秭穰溝澗正載' +SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'十百千万' +SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'拾佰仟萬' + +ZERO_ALT = u'〇' +ONE_ALT = u'幺' +TWO_ALTS = [u'两', u'兩'] + +POSITIVE = [u'正', u'正'] 
+NEGATIVE = [u'负', u'負'] +POINT = [u'点', u'點'] +# PLUS = [u'加', u'加'] +# SIL = [u'杠', u'槓'] + +FILLER_CHARS = ['呃', '啊'] + +ER_WHITELIST = '(儿女|儿子|儿孙|女儿|儿媳|妻儿|' \ + '胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|' \ + '儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|' \ + '佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)' +ER_WHITELIST_PATTERN = re.compile(ER_WHITELIST) + +# 中文数字系统类型 +NUMBERING_TYPES = ['low', 'mid', 'high'] + +CURRENCY_NAMES = '(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|' \ + '里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)' +CURRENCY_UNITS = '((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)' +COM_QUANTIFIERS = '(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|' \ + '砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|' \ + '针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|' \ + '毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|' \ + '盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|' \ + '纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)' + + +# Punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git) +CN_PUNCS_STOP = '!?。。' +CN_PUNCS_NONSTOP = '"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏·〈〉-' +CN_PUNCS = CN_PUNCS_STOP + CN_PUNCS_NONSTOP + +PUNCS = CN_PUNCS + string.punctuation +PUNCS_TRANSFORM = str.maketrans(PUNCS, ' ' * len(PUNCS), '') # replace puncs with space + + +# https://zh.wikipedia.org/wiki/全行和半行 +QJ2BJ = { + ' ': ' ', + '!': '!', + '"': '"', + '#': '#', + '$': '$', + '%': '%', + '&': '&', + ''': "'", + '(': '(', + ')': ')', + '*': '*', + '+': '+', + ',': ',', + '-': '-', + '.': '.', + '/': '/', + '0': '0', + '1': '1', + '2': '2', + '3': '3', + '4': '4', + '5': '5', + '6': '6', + '7': '7', + '8': '8', + '9': '9', + ':': ':', + ';': ';', + '<': '<', + '=': '=', + '>': '>', + '?': '?', + '@': '@', + 'A': 'A', + 'B': 'B', + 'C': 'C', + 'D': 'D', + 'E': 'E', + 'F': 'F', + 'G': 'G', + 'H': 'H', + 'I': 'I', + 'J': 'J', + 'K': 'K', 
+ 'L': 'L', + 'M': 'M', + 'N': 'N', + 'O': 'O', + 'P': 'P', + 'Q': 'Q', + 'R': 'R', + 'S': 'S', + 'T': 'T', + 'U': 'U', + 'V': 'V', + 'W': 'W', + 'X': 'X', + 'Y': 'Y', + 'Z': 'Z', + '[': '[', + '\': '\\', + ']': ']', + '^': '^', + '_': '_', + '`': '`', + 'a': 'a', + 'b': 'b', + 'c': 'c', + 'd': 'd', + 'e': 'e', + 'f': 'f', + 'g': 'g', + 'h': 'h', + 'i': 'i', + 'j': 'j', + 'k': 'k', + 'l': 'l', + 'm': 'm', + 'n': 'n', + 'o': 'o', + 'p': 'p', + 'q': 'q', + 'r': 'r', + 's': 's', + 't': 't', + 'u': 'u', + 'v': 'v', + 'w': 'w', + 'x': 'x', + 'y': 'y', + 'z': 'z', + '{': '{', + '|': '|', + '}': '}', + '~': '~', +} +QJ2BJ_TRANSFORM = str.maketrans(''.join(QJ2BJ.keys()), ''.join(QJ2BJ.values()), '') + + +# 2013 China National Standard: https://zh.wikipedia.org/wiki/通用规范汉字表, raw resources: +# https://github.com/mozillazg/pinyin-data/blob/master/kMandarin_8105.txt with 8105 chinese chars in total +CN_CHARS_COMMON = ( + '一丁七万丈三上下不与丏丐丑专且丕世丘丙业丛东丝丞丢两严丧个丫中丰串临丸丹为主丽举' + '乂乃久么义之乌乍乎乏乐乒乓乔乖乘乙乜九乞也习乡书乩买乱乳乸乾了予争事二亍于亏云互' + '亓五井亘亚些亟亡亢交亥亦产亨亩享京亭亮亲亳亵亶亸亹人亿什仁仂仃仄仅仆仇仉今介仍从' + '仑仓仔仕他仗付仙仝仞仟仡代令以仨仪仫们仰仲仳仵件价任份仿企伈伉伊伋伍伎伏伐休众优' + '伙会伛伞伟传伢伣伤伥伦伧伪伫伭伯估伲伴伶伸伺似伽伾佁佃但位低住佐佑体何佖佗佘余佚' + '佛作佝佞佟你佣佤佥佩佬佯佰佳佴佶佸佺佻佼佽佾使侁侂侃侄侈侉例侍侏侑侔侗侘供依侠侣' + '侥侦侧侨侩侪侬侮侯侴侵侹便促俄俅俊俍俎俏俐俑俗俘俙俚俜保俞俟信俣俦俨俩俪俫俭修俯' + '俱俳俵俶俸俺俾倌倍倏倒倓倔倕倘候倚倜倞借倡倥倦倧倨倩倪倬倭倮倴债倻值倾偁偃假偈偌' + '偎偏偓偕做停偡健偬偭偰偲偶偷偻偾偿傀傃傅傈傉傍傒傕傣傥傧储傩催傲傺傻僇僎像僔僖僚' + '僦僧僬僭僮僰僳僵僻儆儇儋儒儡儦儳儴儿兀允元兄充兆先光克免兑兔兕兖党兜兢入全八公六' + '兮兰共关兴兵其具典兹养兼兽冀冁内冈冉册再冏冒冔冕冗写军农冠冢冤冥冬冮冯冰冱冲决况' + '冶冷冻冼冽净凄准凇凉凋凌减凑凓凘凛凝几凡凤凫凭凯凰凳凶凸凹出击凼函凿刀刁刃分切刈' + '刊刍刎刑划刖列刘则刚创初删判刨利别刬刭刮到刳制刷券刹刺刻刽刿剀剁剂剃剅削剋剌前剐' + '剑剔剕剖剜剞剟剡剥剧剩剪副割剽剿劁劂劄劈劐劓力劝办功加务劢劣动助努劫劬劭励劲劳劼' + '劾势勃勇勉勋勍勐勒勔勖勘勚募勠勤勰勺勾勿匀包匆匈匍匏匐匕化北匙匜匝匠匡匣匦匪匮匹' + '区医匼匾匿十千卅升午卉半华协卑卒卓单卖南博卜卞卟占卡卢卣卤卦卧卫卬卮卯印危即却卵' + '卷卸卺卿厂厄厅历厉压厌厍厕厖厘厚厝原厢厣厥厦厨厩厮去厾县叁参叆叇又叉及友双反发叔' + '叕取受变叙叚叛叟叠口古句另叨叩只叫召叭叮可台叱史右叵叶号司叹叻叼叽吁吃各吆合吉吊' + '同名后吏吐向吒吓吕吖吗君吝吞吟吠吡吣否吧吨吩含听吭吮启吱吲吴吵吸吹吻吼吽吾呀呃呆' + '呇呈告呋呐呒呓呔呕呖呗员呙呛呜呢呣呤呦周呱呲味呵呶呷呸呻呼命咀咂咄咆咇咉咋和咍咎' + '咏咐咒咔咕咖咙咚咛咝咡咣咤咥咦咧咨咩咪咫咬咯咱咳咴咸咺咻咽咿哀品哂哃哄哆哇哈哉哌' + '响哎哏哐哑哒哓哔哕哗哙哚哝哞哟哢哥哦哧哨哩哪哭哮哱哲哳哺哼哽哿唁唆唇唉唏唐唑唔唛' + '唝唠唢唣唤唧唪唬售唯唰唱唳唵唷唼唾唿啁啃啄商啉啊啐啕啖啜啡啤啥啦啧啪啫啬啭啮啰啴' + 
'啵啶啷啸啻啼啾喀喁喂喃善喆喇喈喉喊喋喏喑喔喘喙喜喝喟喤喧喱喳喵喷喹喻喽喾嗄嗅嗉嗌' + '嗍嗐嗑嗒嗓嗔嗖嗜嗝嗞嗟嗡嗣嗤嗥嗦嗨嗪嗫嗬嗯嗲嗳嗵嗷嗽嗾嘀嘁嘈嘉嘌嘎嘏嘘嘚嘛嘞嘟嘡' + '嘣嘤嘧嘬嘭嘱嘲嘴嘶嘹嘻嘿噀噂噇噌噍噎噔噗噘噙噜噢噤器噩噪噫噬噱噶噻噼嚄嚅嚆嚎嚏嚓' + '嚚嚣嚭嚯嚷嚼囊囔囚四回囟因囡团囤囫园困囱围囵囷囹固国图囿圃圄圆圈圉圊圌圐圙圜土圢' + '圣在圩圪圫圬圭圮圯地圲圳圹场圻圾址坂均坉坊坋坌坍坎坏坐坑坒块坚坛坜坝坞坟坠坡坤坥' + '坦坨坩坪坫坬坭坯坰坳坷坻坼坽垂垃垄垆垈型垌垍垎垏垒垓垕垙垚垛垞垟垠垡垢垣垤垦垧垩' + '垫垭垮垯垱垲垴垵垸垺垾垿埂埃埆埇埋埌城埏埒埔埕埗埘埙埚埝域埠埤埪埫埭埯埴埵埸培基' + '埼埽堂堃堆堇堉堋堌堍堎堐堑堕堙堞堠堡堤堧堨堪堰堲堵堼堽堾塄塅塆塌塍塑塔塘塝塞塥填' + '塬塱塾墀墁境墅墈墉墐墒墓墕墘墙墚增墟墡墣墦墨墩墼壁壅壑壕壤士壬壮声壳壶壸壹处备复' + '夏夐夔夕外夙多夜够夤夥大天太夫夬夭央夯失头夷夸夹夺夼奁奂奄奇奈奉奋奎奏契奓奔奕奖' + '套奘奚奠奡奢奥奭女奴奶奸她好妁如妃妄妆妇妈妊妍妒妓妖妗妘妙妞妣妤妥妧妨妩妪妫妭妮' + '妯妲妹妻妾姆姈姊始姐姑姒姓委姗姘姚姜姝姞姣姤姥姨姬姮姱姶姹姻姽姿娀威娃娄娅娆娇娈' + '娉娌娑娓娘娜娟娠娣娥娩娱娲娴娵娶娼婀婆婉婊婌婍婕婘婚婞婠婢婤婧婪婫婳婴婵婶婷婺婻' + '婼婿媂媄媆媒媓媖媚媛媞媪媭媱媲媳媵媸媾嫁嫂嫄嫉嫌嫒嫔嫕嫖嫘嫚嫜嫠嫡嫣嫦嫩嫪嫫嫭嫱' + '嫽嬉嬖嬗嬛嬥嬬嬴嬷嬿孀孅子孑孓孔孕孖字存孙孚孛孜孝孟孢季孤孥学孩孪孬孰孱孳孵孺孽' + '宁它宄宅宇守安宋完宏宓宕宗官宙定宛宜宝实宠审客宣室宥宦宧宪宫宬宰害宴宵家宸容宽宾' + '宿寁寂寄寅密寇富寐寒寓寝寞察寡寤寥寨寮寰寸对寺寻导寿封射将尉尊小少尔尕尖尘尚尜尝' + '尢尤尥尧尨尪尬就尴尸尹尺尻尼尽尾尿局屁层屃居屈屉届屋屎屏屐屑展屙属屠屡屣履屦屯山' + '屹屺屼屾屿岁岂岈岊岌岍岐岑岔岖岗岘岙岚岛岜岞岠岢岣岨岩岫岬岭岱岳岵岷岸岽岿峁峂峃' + '峄峋峒峗峘峙峛峡峣峤峥峦峧峨峪峭峰峱峻峿崀崁崂崃崄崆崇崌崎崒崔崖崚崛崞崟崡崤崦崧' + '崩崭崮崴崶崽崾崿嵁嵅嵇嵊嵋嵌嵎嵖嵘嵚嵛嵝嵩嵫嵬嵯嵲嵴嶂嶅嶍嶒嶓嶙嶝嶟嶦嶲嶷巅巇巉' + '巍川州巡巢工左巧巨巩巫差巯己已巳巴巷巽巾币市布帅帆师希帏帐帑帔帕帖帘帙帚帛帜帝帡' + '带帧帨席帮帱帷常帻帼帽幂幄幅幌幔幕幖幛幞幡幢幪干平年并幸幺幻幼幽广庄庆庇床庋序庐' + '庑库应底庖店庙庚府庞废庠庤庥度座庭庱庳庵庶康庸庹庼庾廆廉廊廋廑廒廓廖廙廛廨廪延廷' + '建廿开弁异弃弄弆弇弈弊弋式弑弓引弗弘弛弟张弢弥弦弧弨弩弭弯弱弶弸弹强弼彀归当录彖' + '彗彘彝彟形彤彦彧彩彪彬彭彰影彳彷役彻彼往征徂径待徇很徉徊律徐徒徕得徘徙徛徜御徨循' + '徭微徵德徼徽心必忆忉忌忍忏忐忑忒忖志忘忙忝忞忠忡忤忧忪快忭忮忱忳念忸忺忻忽忾忿怀' + '态怂怃怄怅怆怊怍怎怏怒怔怕怖怙怛怜思怠怡急怦性怨怩怪怫怯怵总怼怿恁恂恃恋恍恐恒恓' + '恔恕恙恚恝恢恣恤恧恨恩恪恫恬恭息恰恳恶恸恹恺恻恼恽恿悃悄悆悈悉悌悍悒悔悖悚悛悝悟' + '悠悢患悦您悫悬悭悯悰悱悲悴悸悻悼情惆惇惊惋惎惑惔惕惘惙惚惛惜惝惟惠惦惧惨惩惫惬惭' + '惮惯惰想惴惶惹惺愀愁愃愆愈愉愍愎意愐愔愕愚感愠愣愤愦愧愫愭愿慆慈慊慌慎慑慕慝慢慥' + '慧慨慬慭慰慵慷憋憎憔憕憙憧憨憩憬憭憷憺憾懂懈懊懋懑懒懔懦懵懿戆戈戊戋戌戍戎戏成我' + '戒戕或戗战戚戛戟戡戢戣戤戥截戬戭戮戳戴户戽戾房所扁扂扃扅扆扇扈扉扊手才扎扑扒打扔' + '托扛扞扣扦执扩扪扫扬扭扮扯扰扳扶批扺扼扽找承技抃抄抉把抑抒抓抔投抖抗折抚抛抟抠抡' + '抢护报抨披抬抱抵抹抻押抽抿拂拃拄担拆拇拈拉拊拌拍拎拐拒拓拔拖拗拘拙招拜拟拢拣拤拥' + '拦拧拨择括拭拮拯拱拳拴拶拷拼拽拾拿持挂指挈按挎挑挓挖挚挛挝挞挟挠挡挣挤挥挦挨挪挫' + '振挲挹挺挽捂捃捅捆捉捋捌捍捎捏捐捕捞损捡换捣捧捩捭据捯捶捷捺捻捽掀掂掇授掉掊掌掎' + '掏掐排掖掘掞掠探掣接控推掩措掬掭掮掰掳掴掷掸掺掼掾揄揆揉揍描提插揕揖揠握揣揩揪揭' + '揳援揶揸揽揿搀搁搂搅搋搌搏搐搒搓搔搛搜搞搠搡搦搪搬搭搴携搽摁摄摅摆摇摈摊摏摒摔摘' + '摛摞摧摩摭摴摸摹摽撂撄撅撇撑撒撕撖撙撞撤撩撬播撮撰撵撷撸撺撼擀擂擅操擎擐擒擘擞擢' + '擤擦擿攀攉攒攘攥攫攮支收攸改攻攽放政故效敉敌敏救敔敕敖教敛敝敞敢散敦敩敫敬数敲整' + '敷文斋斌斐斑斓斗料斛斜斝斟斠斡斤斥斧斩斫断斯新斶方於施旁旃旄旅旆旋旌旎族旐旒旖旗' + '旞无既日旦旧旨早旬旭旮旯旰旱旴旵时旷旸旺旻旿昀昂昃昄昆昇昈昉昊昌明昏昒易昔昕昙昝' + '星映昡昣昤春昧昨昪昫昭是昱昳昴昵昶昺昼昽显晁晃晅晊晋晌晏晐晒晓晔晕晖晗晙晚晞晟晡' + 
'晢晤晦晨晪晫普景晰晱晴晶晷智晾暂暄暅暇暌暑暕暖暗暝暧暨暮暲暴暵暶暹暾暿曈曌曙曛曜' + '曝曦曩曰曲曳更曷曹曼曾替最月有朋服朏朐朓朔朕朗望朝期朦木未末本札术朱朳朴朵朸机朽' + '杀杂权杄杆杈杉杌李杏材村杓杕杖杙杜杞束杠条来杧杨杩杪杭杯杰杲杳杵杷杻杼松板极构枅' + '枇枉枋枍析枕林枘枚果枝枞枢枣枥枧枨枪枫枭枯枰枲枳枵架枷枸枹柁柃柄柈柊柏某柑柒染柔' + '柖柘柙柚柜柝柞柠柢查柩柬柯柰柱柳柴柷柽柿栀栅标栈栉栊栋栌栎栏栐树栒栓栖栗栝栟校栩' + '株栲栳栴样核根栻格栽栾桀桁桂桃桄桅框案桉桊桌桎桐桑桓桔桕桠桡桢档桤桥桦桧桨桩桫桯' + '桲桴桶桷桹梁梃梅梆梌梏梓梗梠梢梣梦梧梨梭梯械梳梴梵梼梽梾梿检棁棂棉棋棍棐棒棓棕棘' + '棚棠棣棤棨棪棫棬森棰棱棵棹棺棻棼棽椀椁椅椆椋植椎椐椑椒椓椟椠椤椪椭椰椴椸椹椽椿楂' + '楒楔楗楙楚楝楞楠楣楦楩楪楫楮楯楷楸楹楼概榃榄榅榆榇榈榉榍榑榔榕榖榛榜榧榨榫榭榰榱' + '榴榷榻槁槃槊槌槎槐槔槚槛槜槟槠槭槱槲槽槿樊樗樘樟模樨横樯樱樵樽樾橄橇橐橑橘橙橛橞' + '橡橥橦橱橹橼檀檄檎檐檑檗檞檠檩檫檬櫆欂欠次欢欣欤欧欲欸欹欺欻款歃歅歆歇歉歌歙止正' + '此步武歧歪歹死歼殁殂殃殄殆殇殉殊残殍殒殓殖殚殛殡殣殪殳殴段殷殿毁毂毅毋毌母每毐毒' + '毓比毕毖毗毙毛毡毪毫毯毳毵毹毽氅氆氇氍氏氐民氓气氕氖氘氙氚氛氟氡氢氤氦氧氨氩氪氮' + '氯氰氲水永氾氿汀汁求汆汇汈汉汊汋汐汔汕汗汛汜汝汞江池污汤汧汨汩汪汫汭汰汲汴汶汹汽' + '汾沁沂沃沄沅沆沇沈沉沌沏沐沓沔沘沙沚沛沟没沣沤沥沦沧沨沩沪沫沭沮沱河沸油沺治沼沽' + '沾沿泂泃泄泅泇泉泊泌泐泓泔法泖泗泙泚泛泜泞泠泡波泣泥注泪泫泮泯泰泱泳泵泷泸泺泻泼' + '泽泾洁洄洇洈洋洌洎洑洒洓洗洘洙洚洛洞洢洣津洧洨洪洫洭洮洱洲洳洴洵洸洹洺活洼洽派洿' + '流浃浅浆浇浈浉浊测浍济浏浐浑浒浓浔浕浙浚浛浜浞浟浠浡浣浥浦浩浪浬浭浮浯浰浲浴海浸' + '浼涂涄涅消涉涌涍涎涐涑涓涔涕涘涛涝涞涟涠涡涢涣涤润涧涨涩涪涫涮涯液涴涵涸涿淀淄淅' + '淆淇淋淌淏淑淖淘淙淜淝淞淟淠淡淤淦淫淬淮淯深淳淴混淹添淼清渊渌渍渎渐渑渔渗渚渝渟' + '渠渡渣渤渥温渫渭港渰渲渴游渺渼湃湄湉湍湎湑湓湔湖湘湛湜湝湟湣湫湮湲湴湾湿溁溃溅溆' + '溇溉溍溏源溘溚溜溞溟溠溢溥溦溧溪溯溱溲溴溵溶溷溹溺溻溽滁滂滃滆滇滉滋滍滏滑滓滔滕' + '滗滘滚滞滟滠满滢滤滥滦滧滨滩滪滫滴滹漂漆漈漉漋漏漓演漕漖漠漤漦漩漪漫漭漯漱漳漴漶' + '漷漹漻漼漾潆潇潋潍潏潖潘潜潞潟潢潦潩潭潮潲潴潵潸潺潼潽潾澂澄澈澉澌澍澎澛澜澡澥澧' + '澪澭澳澴澶澹澼澽激濂濉濋濑濒濞濠濡濩濮濯瀌瀍瀑瀔瀚瀛瀣瀱瀵瀹瀼灈灌灏灞火灭灯灰灵' + '灶灸灼灾灿炀炅炆炉炊炌炎炒炔炕炖炘炙炜炝炟炣炫炬炭炮炯炱炳炷炸点炻炼炽烀烁烂烃烈' + '烊烔烘烙烛烜烝烟烠烤烦烧烨烩烫烬热烯烶烷烹烺烻烽焆焉焊焌焐焓焕焖焗焘焙焚焜焞焦焯' + '焰焱然煁煃煅煊煋煌煎煓煜煞煟煤煦照煨煮煲煳煴煸煺煽熄熇熊熏熔熘熙熛熜熟熠熥熨熬熵' + '熹熻燃燊燋燎燏燔燕燚燠燥燧燮燹爆爇爔爚爝爟爨爪爬爰爱爵父爷爸爹爻爽爿牁牂片版牌牍' + '牒牖牙牚牛牝牟牡牢牤牥牦牧物牮牯牲牵特牺牻牾牿犀犁犄犇犊犋犍犏犒犟犨犬犯犰犴状犷' + '犸犹狁狂狃狄狈狉狍狎狐狒狗狙狝狞狠狡狨狩独狭狮狯狰狱狲狳狴狷狸狺狻狼猁猃猄猇猊猎' + '猕猖猗猛猜猝猞猡猢猥猩猪猫猬献猯猰猱猴猷猹猺猾猿獍獐獒獗獠獬獭獯獴獾玃玄率玉王玎' + '玑玒玓玕玖玘玙玚玛玞玟玠玡玢玤玥玦玩玫玭玮环现玱玲玳玶玷玹玺玻玼玿珀珂珅珇珈珉珊' + '珋珌珍珏珐珑珒珕珖珙珛珝珞珠珢珣珥珦珧珩珪珫班珰珲珵珷珸珹珺珽琀球琄琅理琇琈琉琊' + '琎琏琐琔琚琛琟琡琢琤琥琦琨琪琫琬琭琮琯琰琲琳琴琵琶琼瑀瑁瑂瑃瑄瑅瑆瑑瑓瑔瑕瑖瑗瑙' + '瑚瑛瑜瑝瑞瑟瑢瑧瑨瑬瑭瑰瑱瑳瑶瑷瑾璀璁璃璆璇璈璋璎璐璒璘璜璞璟璠璥璧璨璩璪璬璮璱' + '璲璺瓀瓒瓖瓘瓜瓞瓠瓢瓣瓤瓦瓮瓯瓴瓶瓷瓻瓿甄甍甏甑甓甗甘甚甜生甡甥甦用甩甪甫甬甭甯' + '田由甲申电男甸町画甾畀畅畈畋界畎畏畔畖留畚畛畜畤略畦番畬畯畲畴畸畹畿疁疃疆疍疏疐' + '疑疔疖疗疙疚疝疟疠疡疢疣疤疥疫疬疭疮疯疰疱疲疳疴疵疸疹疼疽疾痂痃痄病症痈痉痊痍痒' + '痓痔痕痘痛痞痢痣痤痦痧痨痪痫痰痱痴痹痼痿瘀瘁瘃瘅瘆瘊瘌瘐瘕瘗瘘瘙瘛瘟瘠瘢瘤瘥瘦瘩' + '瘪瘫瘭瘰瘳瘴瘵瘸瘼瘾瘿癀癃癌癍癔癖癗癜癞癣癫癯癸登白百癿皂的皆皇皈皋皎皑皓皕皖皙' + '皛皞皤皦皭皮皱皲皴皿盂盅盆盈盉益盍盎盏盐监盒盔盖盗盘盛盟盥盦目盯盱盲直盷相盹盼盾' + '省眄眇眈眉眊看眍眙眚真眠眢眦眨眩眬眭眯眵眶眷眸眺眼着睁睃睄睇睎睐睑睚睛睡睢督睥睦' + '睨睫睬睹睽睾睿瞀瞄瞅瞋瞌瞍瞎瞑瞒瞟瞠瞢瞥瞧瞩瞪瞫瞬瞭瞰瞳瞵瞻瞽瞿矍矗矛矜矞矢矣知' + 
'矧矩矫矬短矮矰石矶矸矻矼矾矿砀码砂砄砆砉砌砍砑砒研砖砗砘砚砜砝砟砠砣砥砧砫砬砭砮' + '砰破砵砷砸砹砺砻砼砾础硁硅硇硊硌硍硎硐硒硔硕硖硗硙硚硝硪硫硬硭确硼硿碃碇碈碉碌碍' + '碎碏碑碓碗碘碚碛碜碟碡碣碥碧碨碰碱碲碳碴碶碹碾磁磅磉磊磋磏磐磔磕磙磜磡磨磬磲磴磷' + '磹磻礁礅礌礓礞礴礵示礼社祀祁祃祆祇祈祉祊祋祎祏祐祓祕祖祗祚祛祜祝神祟祠祢祥祧票祭' + '祯祲祷祸祺祼祾禀禁禄禅禊禋福禒禔禘禚禛禤禧禳禹禺离禽禾秀私秃秆秉秋种科秒秕秘租秣' + '秤秦秧秩秫秬秭积称秸移秽秾稀稂稃稆程稌稍税稑稔稗稙稚稞稠稣稳稷稹稻稼稽稿穄穆穑穗' + '穙穜穟穰穴究穷穸穹空穿窀突窃窄窅窈窊窍窎窑窒窕窖窗窘窜窝窟窠窣窥窦窨窬窭窳窸窿立' + '竑竖竘站竞竟章竣童竦竫竭端竹竺竽竿笃笄笆笈笊笋笏笑笔笕笙笛笞笠笤笥符笨笪笫第笮笯' + '笱笳笸笺笼笾筀筅筇等筋筌筏筐筑筒答策筘筚筛筜筝筠筢筤筥筦筮筱筲筵筶筷筹筻筼签简箅' + '箍箐箓箔箕箖算箜管箢箦箧箨箩箪箫箬箭箱箴箸篁篆篇篌篑篓篙篚篝篡篥篦篪篮篯篱篷篼篾' + '簃簇簉簋簌簏簕簖簝簟簠簧簪簰簸簿籀籁籍籥米籴类籼籽粉粑粒粕粗粘粜粝粞粟粢粤粥粪粮' + '粱粲粳粹粼粽精粿糁糅糇糈糊糌糍糒糕糖糗糙糜糟糠糨糯糵系紊素索紧紫累絜絮絷綦綮縠縢' + '縻繁繄繇纂纛纠纡红纣纤纥约级纨纩纪纫纬纭纮纯纰纱纲纳纴纵纶纷纸纹纺纻纼纽纾线绀绁' + '绂练组绅细织终绉绊绋绌绍绎经绐绑绒结绔绕绖绗绘给绚绛络绝绞统绠绡绢绣绤绥绦继绨绩' + '绪绫续绮绯绰绱绲绳维绵绶绷绸绹绺绻综绽绾绿缀缁缂缃缄缅缆缇缈缉缊缌缎缐缑缒缓缔缕' + '编缗缘缙缚缛缜缝缞缟缠缡缢缣缤缥缦缧缨缩缪缫缬缭缮缯缰缱缲缳缴缵缶缸缺罂罄罅罍罐' + '网罔罕罗罘罚罟罡罢罨罩罪置罱署罴罶罹罽罾羁羊羌美羑羓羔羕羖羚羝羞羟羡群羧羯羰羱羲' + '羸羹羼羽羿翀翁翂翃翅翈翊翌翎翔翕翘翙翚翛翟翠翡翥翦翩翮翯翰翱翳翷翻翼翾耀老考耄者' + '耆耇耋而耍耏耐耑耒耔耕耖耗耘耙耜耠耢耤耥耦耧耨耩耪耰耱耳耵耶耷耸耻耽耿聂聃聆聊聋' + '职聍聒联聘聚聩聪聱聿肃肄肆肇肉肋肌肓肖肘肚肛肝肟肠股肢肤肥肩肪肫肭肮肯肱育肴肷肸' + '肺肼肽肾肿胀胁胂胃胄胆胈背胍胎胖胗胙胚胛胜胝胞胠胡胣胤胥胧胨胩胪胫胬胭胯胰胱胲胳' + '胴胶胸胺胼能脂脆脉脊脍脎脏脐脑脒脓脔脖脘脚脞脟脩脬脯脱脲脶脸脾脿腆腈腊腋腌腐腑腒' + '腓腔腕腘腙腚腠腥腧腨腩腭腮腯腰腱腴腹腺腻腼腽腾腿膀膂膈膊膏膑膘膙膛膜膝膦膨膳膺膻' + '臀臂臃臆臊臌臑臜臣臧自臬臭至致臻臼臾舀舁舂舄舅舆舌舍舐舒舔舛舜舞舟舠舢舣舥航舫般' + '舭舯舰舱舲舳舴舵舶舷舸船舻舾艄艅艇艉艋艎艏艘艚艟艨艮良艰色艳艴艺艽艾艿节芃芄芈芊' + '芋芍芎芏芑芒芗芘芙芜芝芟芠芡芣芤芥芦芨芩芪芫芬芭芮芯芰花芳芴芷芸芹芼芽芾苁苄苇苈' + '苉苊苋苌苍苎苏苑苒苓苔苕苗苘苛苜苞苟苠苡苣苤若苦苧苫苯英苴苷苹苻苾茀茁茂范茄茅茆' + '茈茉茋茌茎茏茑茓茔茕茗茚茛茜茝茧茨茫茬茭茯茱茳茴茵茶茸茹茺茼茽荀荁荃荄荆荇草荏荐' + '荑荒荓荔荖荙荚荛荜荞荟荠荡荣荤荥荦荧荨荩荪荫荬荭荮药荷荸荻荼荽莅莆莉莎莒莓莘莙莛' + '莜莝莞莠莨莩莪莫莰莱莲莳莴莶获莸莹莺莼莽莿菀菁菂菅菇菉菊菌菍菏菔菖菘菜菝菟菠菡菥' + '菩菪菰菱菲菹菼菽萁萃萄萆萋萌萍萎萏萑萘萚萜萝萣萤营萦萧萨萩萱萳萸萹萼落葆葎葑葖著' + '葙葚葛葜葡董葩葫葬葭葰葱葳葴葵葶葸葺蒂蒄蒇蒈蒉蒋蒌蒎蒐蒗蒙蒜蒟蒡蒨蒯蒱蒲蒴蒸蒹蒺' + '蒻蒽蒿蓁蓂蓄蓇蓉蓊蓍蓏蓐蓑蓓蓖蓝蓟蓠蓢蓣蓥蓦蓬蓰蓼蓿蔀蔃蔈蔊蔌蔑蔓蔗蔚蔟蔡蔫蔬蔷' + '蔸蔹蔺蔻蔼蔽蕃蕈蕉蕊蕖蕗蕙蕞蕤蕨蕰蕲蕴蕹蕺蕻蕾薁薄薅薇薏薛薜薢薤薨薪薮薯薰薳薷薸' + '薹薿藁藉藏藐藓藕藜藟藠藤藦藨藩藻藿蘅蘑蘖蘘蘧蘩蘸蘼虎虏虐虑虒虓虔虚虞虢虤虫虬虮虱' + '虷虸虹虺虻虼虽虾虿蚀蚁蚂蚄蚆蚊蚋蚌蚍蚓蚕蚜蚝蚣蚤蚧蚨蚩蚪蚬蚯蚰蚱蚲蚴蚶蚺蛀蛃蛄蛆' + '蛇蛉蛊蛋蛎蛏蛐蛑蛔蛘蛙蛛蛞蛟蛤蛩蛭蛮蛰蛱蛲蛳蛴蛸蛹蛾蜀蜂蜃蜇蜈蜉蜊蜍蜎蜐蜒蜓蜕蜗' + '蜘蜚蜜蜞蜡蜢蜣蜥蜩蜮蜱蜴蜷蜻蜾蜿蝇蝈蝉蝌蝎蝓蝗蝘蝙蝠蝣蝤蝥蝮蝰蝲蝴蝶蝻蝼蝽蝾螂螃' + '螅螈螋融螗螟螠螣螨螫螬螭螯螱螳螵螺螽蟀蟆蟊蟋蟏蟑蟒蟛蟠蟥蟪蟫蟮蟹蟾蠃蠊蠋蠓蠕蠖蠡' + '蠢蠲蠹蠼血衃衄衅行衍衎衒衔街衙衠衡衢衣补表衩衫衬衮衰衲衷衽衾衿袁袂袄袅袆袈袋袍袒' + '袖袗袜袢袤袪被袭袯袱袷袼裁裂装裆裈裉裎裒裔裕裘裙裛裟裢裣裤裥裨裰裱裳裴裸裹裼裾褂' + '褊褐褒褓褕褙褚褛褟褡褥褪褫褯褰褴褶襁襄襕襚襜襞襟襦襫襻西要覃覆见观觃规觅视觇览觉' + '觊觋觌觎觏觐觑角觖觚觜觞觟解觥触觫觭觯觱觳觿言訄訇訚訾詈詟詹誉誊誓謇警譬计订讣认' + '讥讦讧讨让讪讫训议讯记讱讲讳讴讵讶讷许讹论讻讼讽设访诀证诂诃评诅识诇诈诉诊诋诌词' + 
'诎诏诐译诒诓诔试诖诗诘诙诚诛诜话诞诟诠诡询诣诤该详诧诨诩诫诬语诮误诰诱诲诳说诵请' + '诸诹诺读诼诽课诿谀谁谂调谄谅谆谇谈谊谋谌谍谎谏谐谑谒谓谔谕谖谗谙谚谛谜谝谞谟谠谡' + '谢谣谤谥谦谧谨谩谪谫谬谭谮谯谰谱谲谳谴谵谶谷谼谿豁豆豇豉豌豕豚象豢豨豪豫豮豳豸豹' + '豺貂貅貆貉貊貌貔貘贝贞负贡财责贤败账货质贩贪贫贬购贮贯贰贱贲贳贴贵贶贷贸费贺贻贼' + '贽贾贿赀赁赂赃资赅赆赇赈赉赊赋赌赍赎赏赐赑赒赓赔赕赖赗赘赙赚赛赜赝赞赟赠赡赢赣赤' + '赦赧赪赫赭走赳赴赵赶起趁趄超越趋趑趔趟趣趯趱足趴趵趸趺趼趾趿跂跃跄跆跋跌跎跏跐跑' + '跖跗跚跛距跞跟跣跤跨跪跬路跱跳践跶跷跸跹跺跻跽踅踉踊踌踏踒踔踝踞踟踢踣踦踩踪踬踮' + '踯踱踵踶踹踺踽蹀蹁蹂蹄蹅蹇蹈蹉蹊蹋蹐蹑蹒蹙蹚蹜蹢蹦蹩蹬蹭蹯蹰蹲蹴蹶蹼蹽蹾蹿躁躅躇' + '躏躐躔躜躞身躬躯躲躺车轧轨轩轪轫转轭轮软轰轱轲轳轴轵轶轷轸轹轺轻轼载轾轿辀辁辂较' + '辄辅辆辇辈辉辊辋辌辍辎辏辐辑辒输辔辕辖辗辘辙辚辛辜辞辟辣辨辩辫辰辱边辽达辿迁迂迄' + '迅过迈迎运近迓返迕还这进远违连迟迢迤迥迦迨迩迪迫迭迮述迳迷迸迹迺追退送适逃逄逅逆' + '选逊逋逍透逐逑递途逖逗通逛逝逞速造逡逢逦逭逮逯逴逵逶逸逻逼逾遁遂遄遆遇遍遏遐遑遒' + '道遗遘遛遢遣遥遨遭遮遴遵遹遽避邀邂邃邈邋邑邓邕邗邘邙邛邝邠邡邢那邦邨邪邬邮邯邰邱' + '邲邳邴邵邶邸邹邺邻邽邾邿郁郃郄郅郇郈郊郎郏郐郑郓郗郚郛郜郝郡郢郤郦郧部郪郫郭郯郴' + '郸都郾郿鄀鄂鄃鄄鄅鄌鄑鄗鄘鄙鄚鄜鄞鄠鄢鄣鄫鄯鄱鄹酂酃酅酆酉酊酋酌配酎酏酐酒酗酚酝' + '酞酡酢酣酤酥酦酩酪酬酮酯酰酱酲酴酵酶酷酸酹酺酽酾酿醅醇醉醋醌醍醐醑醒醚醛醢醨醪醭' + '醮醯醴醵醺醾采釉释里重野量釐金釜鉴銎銮鋆鋈錾鍪鎏鏊鏖鐾鑫钆钇针钉钊钋钌钍钎钏钐钒' + '钓钔钕钖钗钘钙钚钛钜钝钞钟钠钡钢钣钤钥钦钧钨钩钪钫钬钭钮钯钰钱钲钳钴钵钷钹钺钻钼' + '钽钾钿铀铁铂铃铄铅铆铈铉铊铋铌铍铎铏铐铑铒铕铖铗铘铙铚铛铜铝铞铟铠铡铢铣铤铥铧铨' + '铩铪铫铬铭铮铯铰铱铲铳铴铵银铷铸铹铺铻铼铽链铿销锁锂锃锄锅锆锇锈锉锊锋锌锍锎锏锐' + '锑锒锓锔锕锖锗锘错锚锛锜锝锞锟锡锢锣锤锥锦锧锨锩锪锫锬锭键锯锰锱锲锳锴锵锶锷锸锹' + '锺锻锼锽锾锿镀镁镂镃镄镅镆镇镈镉镊镋镌镍镎镏镐镑镒镓镔镕镖镗镘镚镛镜镝镞镠镡镢镣' + '镤镥镦镧镨镩镪镫镬镭镮镯镰镱镲镳镴镵镶长门闩闪闫闭问闯闰闱闲闳间闵闶闷闸闹闺闻闼' + '闽闾闿阀阁阂阃阄阅阆阇阈阉阊阋阌阍阎阏阐阑阒阔阕阖阗阘阙阚阜队阡阪阮阱防阳阴阵阶' + '阻阼阽阿陀陂附际陆陇陈陉陋陌降陎限陑陔陕陛陞陟陡院除陧陨险陪陬陲陴陵陶陷隃隅隆隈' + '隋隍随隐隔隗隘隙障隧隩隰隳隶隹隺隼隽难雀雁雄雅集雇雉雊雌雍雎雏雒雕雠雨雩雪雯雱雳' + '零雷雹雾需霁霄霅霆震霈霉霍霎霏霓霖霜霞霨霪霭霰露霸霹霾青靓靖静靛非靠靡面靥革靬靰' + '靳靴靶靸靺靼靽靿鞁鞅鞋鞍鞑鞒鞔鞘鞠鞡鞣鞧鞨鞫鞬鞭鞮鞯鞲鞳鞴韂韦韧韨韩韪韫韬韭音韵' + '韶页顶顷顸项顺须顼顽顾顿颀颁颂颃预颅领颇颈颉颊颋颌颍颎颏颐频颓颔颖颗题颙颚颛颜额' + '颞颟颠颡颢颤颥颦颧风飏飐飑飒飓飔飕飗飘飙飞食飧飨餍餐餮饔饕饥饧饨饩饪饫饬饭饮饯饰' + '饱饲饳饴饵饶饷饸饹饺饻饼饽饿馁馃馄馅馆馇馈馉馊馋馌馍馏馐馑馒馓馔馕首馗馘香馝馞馥' + '馧馨马驭驮驯驰驱驲驳驴驵驶驷驸驹驺驻驼驽驾驿骀骁骂骃骄骅骆骇骈骉骊骋验骍骎骏骐骑' + '骒骓骕骖骗骘骙骚骛骜骝骞骟骠骡骢骣骤骥骦骧骨骰骱骶骷骸骺骼髀髁髂髃髅髋髌髎髑髓高' + '髡髢髦髫髭髯髹髻髽鬃鬈鬏鬒鬓鬘鬟鬣鬯鬲鬶鬷鬻鬼魁魂魃魄魅魆魇魈魉魋魍魏魑魔鱼鱽鱾' + '鱿鲀鲁鲂鲃鲅鲆鲇鲈鲉鲊鲋鲌鲍鲎鲏鲐鲑鲒鲔鲕鲖鲗鲘鲙鲚鲛鲜鲝鲞鲟鲠鲡鲢鲣鲤鲥鲦鲧鲨' + '鲩鲪鲫鲬鲭鲮鲯鲰鲱鲲鲳鲴鲵鲷鲸鲹鲺鲻鲼鲽鲾鲿鳀鳁鳂鳃鳄鳅鳇鳈鳉鳊鳌鳍鳎鳏鳐鳑鳒鳓' + '鳔鳕鳖鳗鳘鳙鳚鳛鳜鳝鳞鳟鳠鳡鳢鳣鳤鸟鸠鸡鸢鸣鸤鸥鸦鸧鸨鸩鸪鸫鸬鸭鸮鸯鸰鸱鸲鸳鸵鸶' + '鸷鸸鸹鸺鸻鸼鸽鸾鸿鹀鹁鹂鹃鹄鹅鹆鹇鹈鹉鹊鹋鹌鹍鹎鹏鹐鹑鹒鹔鹕鹖鹗鹘鹙鹚鹛鹜鹝鹞鹟' + '鹠鹡鹢鹣鹤鹦鹧鹨鹩鹪鹫鹬鹭鹮鹯鹰鹱鹲鹳鹴鹾鹿麀麂麇麈麋麑麒麓麖麝麟麦麸麹麻麽麾黄' + '黇黉黍黎黏黑黔默黛黜黝黟黠黡黢黥黧黩黪黯黹黻黼黾鼋鼍鼎鼐鼒鼓鼗鼙鼠鼢鼩鼫鼬鼯鼱鼷' + '鼹鼻鼽鼾齁齇齉齐齑齿龀龁龂龃龄龅龆龇龈龉龊龋龌龙龚龛龟龠龢鿍鿎鿏㑇㑊㕮㘎㙍㙘㙦㛃' + '㛚㛹㟃㠇㠓㤘㥄㧐㧑㧟㫰㬊㬎㬚㭎㭕㮾㰀㳇㳘㳚㴔㵐㶲㸆㸌㺄㻬㽏㿠䁖䂮䃅䃎䅟䌹䎃䎖䏝䏡' + '䏲䐃䓖䓛䓨䓫䓬䗖䗛䗪䗴䜣䝙䢺䢼䣘䥽䦃䲟䲠䲢䴓䴔䴕䴖䴗䴘䴙䶮𠅤𠙶𠳐𡎚𡐓𣗋𣲗𣲘𣸣𤧛𤩽' + '𤫉𥔲𥕢𥖨𥻗𦈡𦒍𦙶𦝼𦭜𦰡𧿹𨐈𨙸𨚕𨟠𨭉𨱇𨱏𨱑𨱔𨺙𩽾𩾃𩾌𪟝𪣻𪤗𪨰𪨶𪩘𪾢𫄧𫄨𫄷𫄸𫇭𫌀𫍣𫍯' + 
'𫍲𫍽𫐄𫐐𫐓𫑡𫓧𫓯𫓶𫓹𫔍𫔎𫔶𫖮𫖯𫖳𫗧𫗴𫘜𫘝𫘦𫘧𫘨𫘪𫘬𫚕𫚖𫚭𫛭𫞩𫟅𫟦𫟹𫟼𫠆𫠊𫠜𫢸𫫇𫭟' + '𫭢𫭼𫮃𫰛𫵷𫶇𫷷𫸩𬀩𬀪𬂩𬃊𬇕𬇙𬇹𬉼𬊈𬊤𬌗𬍛𬍡𬍤𬒈𬒔𬒗𬕂𬘓𬘘𬘡𬘩𬘫𬘬𬘭𬘯𬙂𬙊𬙋𬜬𬜯𬞟' + '𬟁𬟽𬣙𬣞𬣡𬣳𬤇𬤊𬤝𬨂𬨎𬩽𬪩𬬩𬬭𬬮𬬱𬬸𬬹𬬻𬬿𬭁𬭊𬭎𬭚𬭛𬭤𬭩𬭬𬭯𬭳𬭶𬭸𬭼𬮱𬮿𬯀𬯎𬱖𬱟' + '𬳵𬳶𬳽𬳿𬴂𬴃𬴊𬶋𬶍𬶏𬶐𬶟𬶠𬶨𬶭𬶮𬷕𬸘𬸚𬸣𬸦𬸪𬹼𬺈𬺓' +) +CN_CHARS_EXT = '吶诶屌囧飚屄' + +CN_CHARS = CN_CHARS_COMMON + CN_CHARS_EXT +IN_CH_CHARS = { c : True for c in CN_CHARS } + +EN_CHARS = string.ascii_letters + string.digits +IN_EN_CHARS = { c : True for c in EN_CHARS } + +VALID_CHARS = CN_CHARS + EN_CHARS + ' ' +IN_VALID_CHARS = { c : True for c in VALID_CHARS } + +# ================================================================================ # +# basic class +# ================================================================================ # +class ChineseChar(object): + """ + 中文字符 + 每个字符对应简体和繁体, + e.g. 简体 = '负', 繁体 = '負' + 转换时可转换为简体或繁体 + """ + + def __init__(self, simplified, traditional): + self.simplified = simplified + self.traditional = traditional + #self.__repr__ = self.__str__ + + def __str__(self): + return self.simplified or self.traditional or None + + def __repr__(self): + return self.__str__() + + +class ChineseNumberUnit(ChineseChar): + """ + 中文数字/数位字符 + 每个字符除繁简体外还有一个额外的大写字符 + e.g. 
'陆' 和 '陸' + """ + + def __init__(self, power, simplified, traditional, big_s, big_t): + super(ChineseNumberUnit, self).__init__(simplified, traditional) + self.power = power + self.big_s = big_s + self.big_t = big_t + + def __str__(self): + return '10^{}'.format(self.power) + + @classmethod + def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False): + + if small_unit: + return ChineseNumberUnit(power=index + 1, + simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1]) + elif numbering_type == NUMBERING_TYPES[0]: + return ChineseNumberUnit(power=index + 8, + simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1]) + elif numbering_type == NUMBERING_TYPES[1]: + return ChineseNumberUnit(power=(index + 2) * 4, + simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1]) + elif numbering_type == NUMBERING_TYPES[2]: + return ChineseNumberUnit(power=pow(2, index + 3), + simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1]) + else: + raise ValueError( + 'Counting type should be in {0} ({1} provided).'.format(NUMBERING_TYPES, numbering_type)) + + +class ChineseNumberDigit(ChineseChar): + """ + 中文数字字符 + """ + + def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None): + super(ChineseNumberDigit, self).__init__(simplified, traditional) + self.value = value + self.big_s = big_s + self.big_t = big_t + self.alt_s = alt_s + self.alt_t = alt_t + + def __str__(self): + return str(self.value) + + @classmethod + def create(cls, i, v): + return ChineseNumberDigit(i, v[0], v[1], v[2], v[3]) + + +class ChineseMath(ChineseChar): + """ + 中文数位字符 + """ + + def __init__(self, simplified, traditional, symbol, expression=None): + super(ChineseMath, self).__init__(simplified, traditional) + self.symbol = symbol + self.expression = expression + self.big_s = simplified + self.big_t = traditional + + +CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, 
ChineseNumberDigit, ChineseMath + + +class NumberSystem(object): + """ + 中文数字系统 + """ + pass + + +class MathSymbol(object): + """ + 用于中文数字系统的数学符号 (繁/简体), e.g. + positive = ['正', '正'] + negative = ['负', '負'] + point = ['点', '點'] + """ + + def __init__(self, positive, negative, point): + self.positive = positive + self.negative = negative + self.point = point + + def __iter__(self): + for v in self.__dict__.values(): + yield v + + +# class OtherSymbol(object): +# """ +# 其他符号 +# """ +# +# def __init__(self, sil): +# self.sil = sil +# +# def __iter__(self): +# for v in self.__dict__.values(): +# yield v + + +# ================================================================================ # +# basic utils +# ================================================================================ # +def create_system(numbering_type=NUMBERING_TYPES[1]): + """ + 根据数字系统类型返回创建相应的数字系统,默认为 mid + NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型 + low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc. + mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc. + high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc. 
+ 返回对应的数字系统 + """ + + # chinese number units of '亿' and larger + all_larger_units = zip( + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL) + larger_units = [CNU.create(i, v, numbering_type, False) + for i, v in enumerate(all_larger_units)] + # chinese number units of '十, 百, 千, 万' + all_smaller_units = zip( + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL) + smaller_units = [CNU.create(i, v, small_unit=True) + for i, v in enumerate(all_smaller_units)] + # digis + chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS, + BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL) + digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)] + digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT + digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT + digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1] + + # symbols + positive_cn = CM(POSITIVE[0], POSITIVE[1], '+', lambda x: x) + negative_cn = CM(NEGATIVE[0], NEGATIVE[1], '-', lambda x: -x) + point_cn = CM(POINT[0], POINT[1], '.', lambda x, + y: float(str(x) + '.' 
+ str(y))) + # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y))) + system = NumberSystem() + system.units = smaller_units + larger_units + system.digits = digits + system.math = MathSymbol(positive_cn, negative_cn, point_cn) + # system.symbols = OtherSymbol(sil_cn) + return system + + +def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]): + + def get_symbol(char, system): + for u in system.units: + if char in [u.traditional, u.simplified, u.big_s, u.big_t]: + return u + for d in system.digits: + if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]: + return d + for m in system.math: + if char in [m.traditional, m.simplified]: + return m + + def string2symbols(chinese_string, system): + int_string, dec_string = chinese_string, '' + for p in [system.math.point.simplified, system.math.point.traditional]: + if p in chinese_string: + int_string, dec_string = chinese_string.split(p) + break + return [get_symbol(c, system) for c in int_string], \ + [get_symbol(c, system) for c in dec_string] + + def correct_symbols(integer_symbols, system): + """ + 一百八 to 一百八十 + 一亿一千三百万 to 一亿 一千万 三百万 + """ + + if integer_symbols and isinstance(integer_symbols[0], CNU): + if integer_symbols[0].power == 1: + integer_symbols = [system.digits[1]] + integer_symbols + + if len(integer_symbols) > 1: + if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU): + integer_symbols.append( + CNU(integer_symbols[-2].power - 1, None, None, None, None)) + + result = [] + unit_count = 0 + for s in integer_symbols: + if isinstance(s, CND): + result.append(s) + unit_count = 0 + elif isinstance(s, CNU): + current_unit = CNU(s.power, None, None, None, None) + unit_count += 1 + + if unit_count == 1: + result.append(current_unit) + elif unit_count > 1: + for i in range(len(result)): + if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power: + result[-i - 1] = CNU(result[-i - 1].power + + 
current_unit.power, None, None, None, None) + return result + + def compute_value(integer_symbols): + """ + Compute the value. + When current unit is larger than previous unit, current unit * all previous units will be used as all previous units. + e.g. '两千万' = 2000 * 10000 not 2000 + 10000 + """ + value = [0] + last_power = 0 + for s in integer_symbols: + if isinstance(s, CND): + value[-1] = s.value + elif isinstance(s, CNU): + value[-1] *= pow(10, s.power) + if s.power > last_power: + value[:-1] = list(map(lambda v: v * + pow(10, s.power), value[:-1])) + last_power = s.power + value.append(0) + return sum(value) + + system = create_system(numbering_type) + int_part, dec_part = string2symbols(chinese_string, system) + int_part = correct_symbols(int_part, system) + int_str = str(compute_value(int_part)) + dec_str = ''.join([str(d.value) for d in dec_part]) + if dec_part: + return '{0}.{1}'.format(int_str, dec_str) + else: + return int_str + + +def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False, + traditional=False, alt_zero=False, alt_one=False, alt_two=True, + use_zeros=True, use_units=True): + + def get_value(value_string, use_zeros=True): + + striped_string = value_string.lstrip('0') + + # record nothing if all zeros + if not striped_string: + return [] + + # record one digits + elif len(striped_string) == 1: + if use_zeros and len(value_string) != len(striped_string): + return [system.digits[0], system.digits[int(striped_string)]] + else: + return [system.digits[int(striped_string)]] + + # recursively record multiple digits + else: + result_unit = next(u for u in reversed( + system.units) if u.power < len(striped_string)) + result_string = value_string[:-result_unit.power] + return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power:]) + + system = create_system(numbering_type) + + int_dec = number_string.split('.') + if len(int_dec) == 1: + int_string = int_dec[0] + dec_string = "" + elif len(int_dec) == 
2: + int_string = int_dec[0] + dec_string = int_dec[1] + else: + raise ValueError( + "invalid input num string with more than one dot: {}".format(number_string)) + + if use_units and len(int_string) > 1: + result_symbols = get_value(int_string) + else: + result_symbols = [system.digits[int(c)] for c in int_string] + dec_symbols = [system.digits[int(c)] for c in dec_string] + if dec_string: + result_symbols += [system.math.point] + dec_symbols + + if alt_two: + liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t, + system.digits[2].big_s, system.digits[2].big_t) + for i, v in enumerate(result_symbols): + if isinstance(v, CND) and v.value == 2: + next_symbol = result_symbols[i + + 1] if i < len(result_symbols) - 1 else None + previous_symbol = result_symbols[i - 1] if i > 0 else None + if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))): + if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)): + result_symbols[i] = liang + + # if big is True, '两' will not be used and `alt_two` has no impact on output + if big: + attr_name = 'big_' + if traditional: + attr_name += 't' + else: + attr_name += 's' + else: + if traditional: + attr_name = 'traditional' + else: + attr_name = 'simplified' + + result = ''.join([getattr(s, attr_name) for s in result_symbols]) + + # if not use_zeros: + # result = result.strip(getattr(system.digits[0], attr_name)) + + if alt_zero: + result = result.replace( + getattr(system.digits[0], attr_name), system.digits[0].alt_s) + + if alt_one: + result = result.replace( + getattr(system.digits[1], attr_name), system.digits[1].alt_s) + + for i, p in enumerate(POINT): + if result.startswith(p): + return CHINESE_DIGIS[0] + result + + # ^10, 11, .., 19 + if len(result) >= 2 and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0], + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]] and \ + result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], 
BIG_CHINESE_DIGIS_TRADITIONAL[1]]: + result = result[1:] + + return result + + +# ================================================================================ # +# different types of rewriters +# ================================================================================ # +class Cardinal: + """ + CARDINAL类 + """ + + def __init__(self, cardinal=None, chntext=None): + self.cardinal = cardinal + self.chntext = chntext + + def chntext2cardinal(self): + return chn2num(self.chntext) + + def cardinal2chntext(self): + return num2chn(self.cardinal) + +class Digit: + """ + DIGIT类 + """ + + def __init__(self, digit=None, chntext=None): + self.digit = digit + self.chntext = chntext + + # def chntext2digit(self): + # return chn2num(self.chntext) + + def digit2chntext(self): + return num2chn(self.digit, alt_two=False, use_units=False) + + +class TelePhone: + """ + TELEPHONE类 + """ + + def __init__(self, telephone=None, raw_chntext=None, chntext=None): + self.telephone = telephone + self.raw_chntext = raw_chntext + self.chntext = chntext + + # def chntext2telephone(self): + # sil_parts = self.raw_chntext.split('') + # self.telephone = '-'.join([ + # str(chn2num(p)) for p in sil_parts + # ]) + # return self.telephone + + def telephone2chntext(self, fixed=False): + + if fixed: + sil_parts = self.telephone.split('-') + self.raw_chntext = ''.join([ + num2chn(part, alt_two=False, use_units=False) for part in sil_parts + ]) + self.chntext = self.raw_chntext.replace('', '') + else: + sp_parts = self.telephone.strip('+').split() + self.raw_chntext = ''.join([ + num2chn(part, alt_two=False, use_units=False) for part in sp_parts + ]) + self.chntext = self.raw_chntext.replace('', '') + return self.chntext + + +class Fraction: + """ + FRACTION类 + """ + + def __init__(self, fraction=None, chntext=None): + self.fraction = fraction + self.chntext = chntext + + def chntext2fraction(self): + denominator, numerator = self.chntext.split('分之') + return chn2num(numerator) + '/' + 
chn2num(denominator) + + def fraction2chntext(self): + numerator, denominator = self.fraction.split('/') + return num2chn(denominator) + '分之' + num2chn(numerator) + + +class Date: + """ + DATE类 + """ + + def __init__(self, date=None, chntext=None): + self.date = date + self.chntext = chntext + + # def chntext2date(self): + # chntext = self.chntext + # try: + # year, other = chntext.strip().split('年', maxsplit=1) + # year = Digit(chntext=year).digit2chntext() + '年' + # except ValueError: + # other = chntext + # year = '' + # if other: + # try: + # month, day = other.strip().split('月', maxsplit=1) + # month = Cardinal(chntext=month).chntext2cardinal() + '月' + # except ValueError: + # day = chntext + # month = '' + # if day: + # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1] + # else: + # month = '' + # day = '' + # date = year + month + day + # self.date = date + # return self.date + + def date2chntext(self): + date = self.date + try: + year, other = date.strip().split('年', 1) + year = Digit(digit=year).digit2chntext() + '年' + except ValueError: + other = date + year = '' + if other: + try: + month, day = other.strip().split('月', 1) + month = Cardinal(cardinal=month).cardinal2chntext() + '月' + except ValueError: + day = date + month = '' + if day: + day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1] + else: + month = '' + day = '' + chntext = year + month + day + self.chntext = chntext + return self.chntext + + +class Money: + """ + MONEY类 + """ + + def __init__(self, money=None, chntext=None): + self.money = money + self.chntext = chntext + + # def chntext2money(self): + # return self.money + + def money2chntext(self): + money = self.money + pattern = re.compile(r'(\d+(\.\d+)?)') + matchers = pattern.findall(money) + if matchers: + for matcher in matchers: + money = money.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext()) + self.chntext = money + return self.chntext + + +class Percentage: + """ + PERCENTAGE类 + """ + + 
def __init__(self, percentage=None, chntext=None): + self.percentage = percentage + self.chntext = chntext + + def chntext2percentage(self): + return chn2num(self.chntext.strip().strip('百分之')) + '%' + + def percentage2chntext(self): + return '百分之' + num2chn(self.percentage.strip().strip('%')) + + +def normalize_nsw(raw_text): + text = '^' + raw_text + '$' + + # 规范化日期 + pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)") + matchers = pattern.findall(text) + if matchers: + #print('date') + for matcher in matchers: + text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1) + + # 规范化金钱 + pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)") + matchers = pattern.findall(text) + if matchers: + #print('money') + for matcher in matchers: + text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1) + + # 规范化固话/手机号码 + # 手机 + # http://www.jihaoba.com/news/show/13680 + # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198 + # 联通:130、131、132、156、155、186、185、176 + # 电信:133、153、189、180、181、177 + pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D") + matchers = pattern.findall(text) + if matchers: + #print('telephone') + for matcher in matchers: + text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1) + # 固话 + pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D") + matchers = pattern.findall(text) + if matchers: + # print('fixed telephone') + for matcher in matchers: + text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1) + + # 规范化分数 + pattern = re.compile(r"(\d+/\d+)") + matchers = pattern.findall(text) + if matchers: + #print('fraction') + for matcher in matchers: + text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1) + + # 规范化百分数 + text = text.replace('%', '%') + pattern = 
re.compile(r"(\d+(\.\d+)?%)") + matchers = pattern.findall(text) + if matchers: + #print('percentage') + for matcher in matchers: + text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1) + + # 规范化纯数+量词 + pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS) + matchers = pattern.findall(text) + if matchers: + #print('cardinal+quantifier') + for matcher in matchers: + text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1) + + # 规范化数字编号 + pattern = re.compile(r"(\d{4,32})") + matchers = pattern.findall(text) + if matchers: + #print('digit') + for matcher in matchers: + text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1) + + # 规范化纯数 + pattern = re.compile(r"(\d+(\.\d+)?)") + matchers = pattern.findall(text) + if matchers: + #print('cardinal') + for matcher in matchers: + text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1) + + + # restore P2P, O2O, B2C, B2B etc + pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))") + matchers = pattern.findall(text) + if matchers: + # print('particular') + for matcher in matchers: + text = text.replace(matcher[0], matcher[1]+'2'+matcher[2], 1) + + return text.lstrip('^').rstrip('$') + + +def remove_erhua(text): + """ + 去除儿化音词中的儿: + 他女儿在那边儿 -> 他女儿在那边 + """ + + new_str='' + while re.search('儿',text): + a = re.search('儿',text).span() + remove_er_flag = 0 + + if ER_WHITELIST_PATTERN.search(text): + b = ER_WHITELIST_PATTERN.search(text).span() + if b[0] <= a[0]: + remove_er_flag = 1 + + if remove_er_flag == 0 : + new_str = new_str + text[0:a[0]] + text = text[a[1]:] + else: + new_str = new_str + text[0:b[1]] + text = text[b[1]:] + + text = new_str + text + return text + + +def remove_space(text): + tokens = text.split() + new = [] + for k,t in enumerate(tokens): + if k != 0: + if IN_EN_CHARS.get(tokens[k-1][-1]) and IN_EN_CHARS.get(t[0]): + new.append(' ') + new.append(t) + return ''.join(new) + + +class TextNorm: + def 
__init__(self, + to_banjiao:bool = False, + to_upper:bool = False, + to_lower:bool = False, + remove_fillers:bool = False, + remove_erhua:bool = False, + check_chars:bool = False, + remove_space:bool = False, + cc_mode:str = '', + ) : + self.to_banjiao = to_banjiao + self.to_upper = to_upper + self.to_lower = to_lower + self.remove_fillers = remove_fillers + self.remove_erhua = remove_erhua + self.check_chars = check_chars + self.remove_space = remove_space + + self.cc = None + if cc_mode: + from opencc import OpenCC # Open Chinese Convert: pip install opencc + self.cc = OpenCC(cc_mode) + + def __call__(self, text): + if self.cc: + text = self.cc.convert(text) + + if self.to_banjiao: + text = text.translate(QJ2BJ_TRANSFORM) + + if self.to_upper: + text = text.upper() + + if self.to_lower: + text = text.lower() + + if self.remove_fillers: + for c in FILLER_CHARS: + text = text.replace(c, '') + + if self.remove_erhua: + text = remove_erhua(text) + + text = normalize_nsw(text) + + text = text.translate(PUNCS_TRANSFORM) + + if self.check_chars: + for c in text: + if not IN_VALID_CHARS.get(c): + print(f'WARNING: illegal char {c} in: {text}', file=sys.stderr) + return '' + + if self.remove_space: + text = remove_space(text) + + return text + + +if __name__ == '__main__': + p = argparse.ArgumentParser() + + # normalizer options + p.add_argument('--to_banjiao', action='store_true', help='convert quanjiao chars to banjiao') + p.add_argument('--to_upper', action='store_true', help='convert to upper case') + p.add_argument('--to_lower', action='store_true', help='convert to lower case') + p.add_argument('--remove_fillers', action='store_true', help='remove filler chars such as "呃, 啊"') + p.add_argument('--remove_erhua', action='store_true', help='remove erhua chars such as "他女儿在那边儿 -> 他女儿在那边"') + p.add_argument('--check_chars', action='store_true' , help='skip sentences containing illegal chars') + p.add_argument('--remove_space', action='store_true' , help='remove 
whitespace') + p.add_argument('--cc_mode', choices=['', 't2s', 's2t'], default='', help='convert between traditional to simplified') + + # I/O options + p.add_argument('--log_interval', type=int, default=10000, help='log interval in number of processed lines') + p.add_argument('--has_key', action='store_true', help="will be deprecated, set --format ark instead") + p.add_argument('--format', type=str, choices=['txt', 'ark', 'tsv'], default='txt', help='input format') + p.add_argument('ifile', help='input filename, assume utf-8 encoding') + p.add_argument('ofile', help='output filename') + + args = p.parse_args() + + if args.has_key: + args.format = 'ark' + + normalizer = TextNorm( + to_banjiao = args.to_banjiao, + to_upper = args.to_upper, + to_lower = args.to_lower, + remove_fillers = args.remove_fillers, + remove_erhua = args.remove_erhua, + check_chars = args.check_chars, + remove_space = args.remove_space, + cc_mode = args.cc_mode, + ) + + ndone = 0 + with open(args.ifile, 'r', encoding = 'utf8') as istream, open(args.ofile, 'w+', encoding = 'utf8') as ostream: + if args.format == 'tsv': + reader = csv.DictReader(istream, delimiter = '\t') + assert('TEXT' in reader.fieldnames) + print('\t'.join(reader.fieldnames), file=ostream) + + for item in reader: + text = item['TEXT'] + + if text: + text = normalizer(text) + + if text: + item['TEXT'] = text + print('\t'.join([ item[f] for f in reader.fieldnames ]), file = ostream) + + ndone += 1 + if ndone % args.log_interval == 0: + print(f'text norm: {ndone} lines done.', file = sys.stderr, flush = True) + else: + for l in istream: + key, text = '', '' + if args.format == 'ark': # KALDI archive, line format: "key text" + cols = l.strip().split(maxsplit=1) + key, text = cols[0], cols[1] if len(cols) == 2 else '' + else: + text = l.strip() + + if text: + text = normalizer(text) + + if text: + if args.format == 'ark': + print(key + '\t' + text, file = ostream) + else: + print(text, file = ostream) + + ndone += 1 + if ndone % 
args.log_interval == 0: + print(f'text norm: {ndone} lines done.', file = sys.stderr, flush = True) + print(f'text norm: {ndone} lines done in total.', file = sys.stderr, flush = True) + diff --git a/utils/tokenizer.py b/utils/tokenizer.py new file mode 100644 index 0000000..5954039 --- /dev/null +++ b/utils/tokenizer.py @@ -0,0 +1,160 @@ +import os +import subprocess +from enum import Enum +from typing import List + +from utils.logger import logger + + +class TokenizerType(str, Enum): + word = "word" + whitespace = "whitespace" + + +class LangType(str, Enum): + zh = "zh" + en = "en" + + +TOKENIZER_MAPPING = dict() +TOKENIZER_MAPPING['zh'] = TokenizerType.word +TOKENIZER_MAPPING['en'] = TokenizerType.whitespace +TOKENIZER_MAPPING['ru'] = TokenizerType.whitespace +TOKENIZER_MAPPING['ar'] = TokenizerType.whitespace +TOKENIZER_MAPPING['tr'] = TokenizerType.whitespace +TOKENIZER_MAPPING['es'] = TokenizerType.whitespace +TOKENIZER_MAPPING['pt'] = TokenizerType.whitespace +TOKENIZER_MAPPING['id'] = TokenizerType.whitespace +TOKENIZER_MAPPING['he'] = TokenizerType.whitespace +TOKENIZER_MAPPING['ja'] = TokenizerType.word +TOKENIZER_MAPPING['pl'] = TokenizerType.whitespace +TOKENIZER_MAPPING['de'] = TokenizerType.whitespace +TOKENIZER_MAPPING['fr'] = TokenizerType.whitespace +TOKENIZER_MAPPING['nl'] = TokenizerType.whitespace +TOKENIZER_MAPPING['el'] = TokenizerType.whitespace +TOKENIZER_MAPPING['vi'] = TokenizerType.whitespace +TOKENIZER_MAPPING['th'] = TokenizerType.whitespace +TOKENIZER_MAPPING['it'] = TokenizerType.whitespace +TOKENIZER_MAPPING['fa'] = TokenizerType.whitespace +TOKENIZER_MAPPING['ti'] = TokenizerType.word + +import nltk + +import re +from nltk.tokenize import word_tokenize +from nltk.stem import WordNetLemmatizer +lemmatizer = WordNetLemmatizer() + + +class Tokenizer: + @classmethod + def norm_and_tokenize(cls, sentences: List[str], lang: str = None): + tokenizer = TOKENIZER_MAPPING.get(lang, None) + sentences = cls.replace_general_punc(sentences, 
tokenizer) + sentences = cls.norm(sentences, lang) + return cls.tokenize(sentences, lang) + + @classmethod + def tokenize(cls, sentences: List[str], lang: str = None): + tokenizer = TOKENIZER_MAPPING.get(lang, None) + # sentences = cls.replace_general_punc(sentences, tokenizer) + if tokenizer == TokenizerType.word: + return [[ch for ch in sentence] for sentence in sentences] + elif tokenizer == TokenizerType.whitespace: + return [re.findall(r"\w+", sentence.lower()) for sentence in sentences] + else: + logger.error("找不到对应的分词器") + exit(-1) + + @classmethod + def norm(cls, sentences: List[str], lang: LangType = None): + if lang == "zh": + from utils.speechio import textnorm_zh as textnorm + + normalizer = textnorm.TextNorm( + to_banjiao=True, + to_upper=True, + to_lower=False, + remove_fillers=True, + remove_erhua=False, # 这里同批量识别不同,改成了 False + check_chars=False, + remove_space=False, + cc_mode="", + ) + return [normalizer(sentence) for sentence in sentences] + elif lang == "en": + # pwd = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + # with open('./predict.txt', 'w', encoding='utf-8') as fp: + # for idx, sentence in enumerate(sentences): + # fp.write('%s\t%s\n' % (idx, sentence)) + # subprocess.run( + # f'PYTHONPATH={pwd}/utils/speechio python {pwd}/utils/speechio/textnorm_en.py --has_key --to_upper ./predict.txt ./predict_norm.txt', + # shell=True, + # check=True, + # ) + # sentence_norm = [] + # with open('./predict_norm.txt', 'r', encoding='utf-8') as fp: + # for line in fp.readlines(): + # line_split_result = line.strip().split('\t', 1) + # if len(line_split_result) >= 2: + # sentence_norm.append(line_split_result[1]) + # else: + # sentence_norm.append("") + # # 有可能没有 norm 后就没了 + # return sentence_norm + + # sentence_norm = [] + # for sentence in sentences: + # doc = _nlp_en(sentence) + # # 保留单词,去除标点、数字、特殊符号;做词形还原 + # tokens = [token.lemma_ for token in doc if token.is_alpha] + # tokens = [t.upper() for t in tokens] # 根据你的原逻辑 to_upper=True + # 
sentence_norm.append(" ".join(tokens)) + # return sentence_norm + result = [] + for sentence in sentences: + sentence = re.sub(r"[^a-zA-Z\s]", "", sentence) + tokens = word_tokenize(sentence) + tokens = [lemmatizer.lemmatize(t) for t in tokens] + # if to_upper: + # tokens = [t.upper() for t in tokens] + result.append(" ".join(tokens)) + return result + else: + punc = "!?。"#$%&'()*+,-/:;<=>[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘'‛“”„‟…‧﹏.`! #$%^&*()_+-=|';\":/.,?><~·!#¥%……&*()——+-=“:’;、。,?》《{}" + return [sentence.translate(str.maketrans(dict.fromkeys(punc, " "))).lower() for sentence in sentences] + + @classmethod + def replace_general_punc(cls, sentences: List[str], tokenizer: TokenizerType,language:str = None) -> List[str]: + """代替原来的函数 utils.metrics.cut_sentence""" + if language: + tokenizer = TOKENIZER_MAPPING.get(language) + general_puncs = [ + "······", + "......", + "。", + ",", + "?", + "!", + ";", + ":", + "...", + ".", + ",", + "?", + "!", + ";", + ":", + ] + if tokenizer == TokenizerType.whitespace: + replacer = " " + else: + replacer = "" + trans = str.maketrans(dict.fromkeys("".join(general_puncs), replacer)) + ret_sentences = [""] * len(sentences) + for i, sentence in enumerate(sentences): + sentence = sentence.translate(trans) + sentence = sentence.strip() + sentence = sentence.lower() + ret_sentences[i] = sentence + return ret_sentences