initial commit

This commit is contained in:
2025-08-20 14:07:57 +08:00
commit 77d1870987
79 changed files with 13816 additions and 0 deletions

View File

@@ -0,0 +1,10 @@
# Text Normalization
Text Normalization is part of NeMo's `nemo_text_processing` - a Python package that is installed with the `nemo_toolkit`.
It converts text from written form into its verbalized form, e.g. "123" -> "one hundred twenty three".
See [NeMo documentation](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/nlp/text_normalization/wfst/wfst_text_normalization.html) for details.
Tutorial with overview of the package capabilities: [Text_(Inverse)_Normalization.ipynb](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/text_processing/Text_(Inverse)_Normalization.ipynb)
Tutorial on how to customize the underlying grammars: [WFST_Tutorial.ipynb](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/text_processing/WFST_Tutorial.ipynb)

View File

@@ -0,0 +1,13 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,350 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
import string
from collections import defaultdict, namedtuple
from typing import Dict, List, Optional, Set, Tuple
from unicodedata import category
# Token-type tags that get special treatment by the loaders in this module.
EOS_TYPE = "EOS"
PUNCT_TYPE = "PUNCT"
PLAIN_TYPE = "PLAIN"

# One data-set token: its semiotic class, its written form and its spoken form.
Instance = namedtuple('Instance', ['token_type', 'un_normalized', 'normalized'])

# Semiotic class names listed for this data format.
known_types = [
    "PLAIN", "DATE", "CARDINAL", "LETTERS", "VERBATIM",
    "MEASURE", "DECIMAL", "ORDINAL", "DIGIT", "MONEY",
    "TELEPHONE", "ELECTRONIC", "FRACTION", "TIME", "ADDRESS",
]
def _load_kaggle_text_norm_file(file_path: str) -> List[Instance]:
    """
    Loads a text file in the Kaggle Google text normalization file format.

    https://www.kaggle.com/richardwilliamsproat/text-normalization-for-english-russian-and-polish
    Every line holds three tab-separated fields:
    semiotic class, unnormalized text, and the normalized text (or a "self" marker for trivial classes).
    A line whose first field is the end-of-sentence marker closes a sentence.
    E.g.
        PLAIN   Brillantaisia   <self>
        PLAIN   is      <self>
        PUNCT   .       sil
        <eos>   <eos>

    Tokens are lower-cased. PLAIN tokens are stored with the normalized form equal to the
    written form, PUNCT tokens are dropped, and every end-of-sentence line becomes an
    instance of type EOS_TYPE with empty strings.

    Args:
        file_path: file path to text file

    Returns: flat list of instances

    Raises:
        ValueError: if a non-empty, non-EOS line does not have exactly three fields
    """
    res = []
    # Read explicitly as UTF-8 rather than relying on the platform default encoding
    # (presumably the data set is UTF-8 encoded -- confirm for custom data).
    with open(file_path, 'r', encoding='utf-8') as fp:
        for line_no, line in enumerate(fp, start=1):
            line = line.strip()
            if not line:
                # Blank lines carry no token; previously they crashed the tuple unpacking below.
                continue
            parts = line.split("\t")
            if parts[0] == "<eos>":
                res.append(Instance(token_type=EOS_TYPE, un_normalized="", normalized=""))
                continue
            if len(parts) != 3:
                raise ValueError(
                    f"{file_path}:{line_no}: expected 3 tab-separated fields, got {len(parts)}"
                )
            l_type, l_token, l_normalized = parts
            l_token = l_token.lower()
            l_normalized = l_normalized.lower()
            if l_type == PLAIN_TYPE:
                res.append(Instance(token_type=l_type, un_normalized=l_token, normalized=l_token))
            elif l_type != PUNCT_TYPE:
                res.append(Instance(token_type=l_type, un_normalized=l_token, normalized=l_normalized))
    return res
def load_files(file_paths: List[str], load_func=_load_kaggle_text_norm_file) -> List[Instance]:
    """
    Reads every file in `file_paths` with `load_func` and concatenates the results.

    Args:
        file_paths: list of file paths
        load_func: loading function, called as load_func(file_path=...)

    Returns: flat list of instances, in file order
    """
    return [item for path in file_paths for item in load_func(file_path=path)]
def clean_generic(text: str) -> str:
    """
    Generic clean-up that does not touch semiotic classes:
    trims surrounding whitespace and lower-cases the text.

    Args:
        text: string

    Returns: cleaned string
    """
    return text.strip().lower()
def evaluate(preds: List[str], labels: List[str], input: Optional[List[str]] = None, verbose: bool = True) -> float:
    """
    Evaluates accuracy given predictions and labels.

    Predictions and labels are compared after `clean_generic` (strip + lower-case).

    Args:
        preds: predictions
        labels: labels, aligned with `preds`
        input: optional, only needed for verbosity
        verbose: if true prints [input], golden labels and predictions for every mismatch

    Returns: accuracy as the fraction of matching pairs; 0.0 when `preds` is empty
    """
    total = len(preds)
    if total == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0

    correct = 0
    for i, pred in enumerate(preds):
        pred_norm = clean_generic(pred)
        label_norm = clean_generic(labels[i])
        if pred_norm == label_norm:
            correct += 1
        elif verbose:
            # Mismatch report; printed only on request (the flag used to be ignored).
            if input:
                print(f"inpu: {json.dumps(input[i])}")
            print(f"gold: {json.dumps(label_norm)}")
            print(f"pred: {json.dumps(pred_norm)}")
    return correct / total
def training_data_to_tokens(
    data: List[Instance], category: Optional[str] = None
) -> Dict[str, Tuple[List[str], List[str]]]:
    """
    Groups instances by token type, optionally keeping only a single category.
    End-of-sentence instances are always skipped.

    Args:
        data: list of instances
        category: optional semiotic class category name

    Returns Dict: token type -> (list of un_normalized strings, list of normalized strings)
    """
    grouped = defaultdict(lambda: ([], []))
    for item in data:
        kind = item.token_type
        if kind == EOS_TYPE:
            continue
        if category is not None and kind != category:
            continue
        written, spoken = grouped[kind]
        written.append(item.un_normalized)
        spoken.append(item.normalized)
    return grouped
def training_data_to_sentences(data: List[Instance]) -> Tuple[List[str], List[str], List[Set[str]]]:
    """
    Cuts the instance stream into sentences at every EOS instance.
    Instances after the last EOS instance are not emitted.

    Args:
        data: list of instances

    Returns (list of unnormalized sentences, list of normalized sentences, list of sets of categories in a sentence)
    """
    finished = []
    finished_tags = []
    current = []
    current_tags = set()
    for item in data:
        if item.token_type != EOS_TYPE:
            current.append(item)
            current_tags.add(item.token_type)
            continue
        # EOS closes the running sentence and starts a fresh one
        finished.append(current)
        finished_tags.append(current_tags)
        current, current_tags = [], set()

    un_normalized = [" ".join(tok.un_normalized for tok in sent) for sent in finished]
    normalized = [" ".join(tok.normalized for tok in sent) for sent in finished]
    return un_normalized, normalized, finished_tags
def post_process_punctuation(text: str) -> str:
    """
    Normalizes quotes and spaces.

    Removes the space just inside brackets, collapses double spaces, maps typographic
    quotes/backticks to their ASCII counterparts, drops backslashes, joins "- -" into "--"
    and removes the space in front of the punctuation marks ! , . : ; ?

    Args:
        text: text

    Returns: text with normalized spaces and quotes, stripped of surrounding whitespace
    """
    # No space right after an opening or right before a closing bracket.
    for spaced, tight in (('( ', '('), (' )', ')'), ('{ ', '{'), (' }', '}'), ('[ ', '['), (' ]', ']')):
        text = text.replace(spaced, tight)

    # Collapse double spaces (single pass).
    text = text.replace('  ', ' ')

    # Typographic characters -> ASCII. Written as escapes so the literals cannot be lost to
    # an encoding mishap again (an empty search string would make str.replace insert the
    # replacement between every character).
    quote_table = str.maketrans(
        {
            "\u201d": '"',  # right double quotation mark
            "\u201c": '"',  # left double quotation mark
            "\u201e": '"',  # double low-9 quotation mark
            "\u00bb": '"',  # right-pointing double angle quotation mark
            "\u00ab": '"',  # left-pointing double angle quotation mark
            "\u2019": "'",  # right single quotation mark
            "\u2018": "'",  # left single quotation mark
            "\u00b4": "'",  # acute accent
            "`": "'",
            "\\": None,  # backslashes are dropped
        }
    )
    text = text.translate(quote_table)

    text = text.replace('- -', "--")

    for punct in "!,.:;?":
        text = text.replace(f' {punct}', punct)
    return text.strip()
def pre_process(text: str) -> str:
    """
    Optional preprocessing step applied before normalization (part of TTS TN pipeline):
    square brackets are surrounded with spaces and runs of spaces are collapsed.

    Args:
        text: string that may include semiotic classes

    Returns: text with spaces around punctuation marks
    """
    for bracket in '[]':
        # pad every occurrence of the bracket with one space on each side
        text = f' {bracket} '.join(text.split(bracket))
    # remove extra space
    return re.sub(r' +', ' ', text)
def load_file(file_path: str) -> List[str]:
    """
    Reads a text file line by line. Lines are returned as read, i.e. including
    their line terminator.

    Args:
        file_path: file path

    Returns: flat list of string
    """
    with open(file_path, 'r') as handle:
        return [row for row in handle]
def write_file(file_path: str, data: List[str]):
    """
    Writes each string of `data` to `file_path`, one per line (a newline is appended to each).
    An existing file is overwritten.

    Args:
        file_path: file path
        data: list of string
    """
    with open(file_path, 'w') as out:
        out.writelines(entry + '\n' for entry in data)
def post_process_punct(input: str, normalized_text: str, add_unicode_punct: bool = False):
    """
    Post-processing of the normalized output to match input in terms of spaces around punctuation marks.
    After NN normalization, Moses detokenization puts a space after
    punctuation marks, and attaches an opening quote "'" to the word to the right.
    E.g., input to the TN NN model is "12 test' example",
    after normalization and detokenization -> "twelve test 'example" (the quote is considered to be an opening quote,
    but it doesn't match the input and can cause issues during TTS voice generation.)
    The current function will match the punctuation and spaces of the normalized text with the input sequence.
    "12 test' example" -> "twelve test 'example" -> "twelve test' example" (the quote was shifted to match the input).

    Args:
        input: input text (original input to the NN, before normalization or tokenization)
        normalized_text: output text (output of the TN NN model)
        add_unicode_punct: set to True to handle unicode punctuation marks as well as default string.punctuation
            (only characters that actually occur in `input` are inspected)

    Returns: normalized text with spacing around punctuation aligned to the input and runs of spaces collapsed
    """
    # in the post-processing WFST graph "``" are replaced with '"' quotes (otherwise single quotes "`" won't be
    # handled correctly); this function fixes spaces around them based on input sequence, so here we're making
    # the same double quote replacement to make sure these new double quotes work with this function
    if "``" in input and "``" not in normalized_text:
        input = input.replace("``", '"')

    # Per-character lists, so that single positions of the output can be edited in place.
    input = list(input)
    normalized_text = list(normalized_text)

    punct_marks = [x for x in string.punctuation if x in input]
    if add_unicode_punct:
        # Only characters present in the input can matter, so classify those instead of
        # scanning the whole Unicode range. Sorted to keep code-point order.
        punct_marks.extend(
            sorted(
                ch for ch in set(input) if ch not in string.punctuation and category(ch).startswith("P")
            )
        )

    def _is_valid(idx_out, idx_in, normalized_text, input):
        """Check if previous or next word match (for cases when punctuation marks are part of
        semiotic token, i.e. some punctuation can be missing in the normalized text)"""
        return (idx_out > 0 and idx_in > 0 and normalized_text[idx_out - 1] == input[idx_in - 1]) or (
            idx_out < len(normalized_text) - 1
            and idx_in < len(input) - 1
            and normalized_text[idx_out + 1] == input[idx_in + 1]
        )

    for punct in punct_marks:
        try:
            # When the mark occurs equally often on both sides, occurrences are paired up one to one.
            equal = input.count(punct) == normalized_text.count(punct)
            idx_in, idx_out = 0, 0
            while punct in input[idx_in:]:
                idx_out = normalized_text.index(punct, idx_out)
                idx_in = input.index(punct, idx_in)

                if not equal and not _is_valid(idx_out, idx_in, normalized_text, input):
                    # This input occurrence has no reliable counterpart; try the next one.
                    idx_in += 1
                    continue

                # Space before the mark: make the output agree with the input.
                if idx_in > 0 and idx_out > 0:
                    if normalized_text[idx_out - 1] == " " and input[idx_in - 1] != " ":
                        normalized_text[idx_out - 1] = ""
                    elif normalized_text[idx_out - 1] != " " and input[idx_in - 1] == " ":
                        normalized_text[idx_out - 1] += " "

                # Space after the mark: same treatment.
                if idx_in < len(input) - 1 and idx_out < len(normalized_text) - 1:
                    if normalized_text[idx_out + 1] == " " and input[idx_in + 1] != " ":
                        normalized_text[idx_out + 1] = ""
                    elif normalized_text[idx_out + 1] != " " and input[idx_in + 1] == " ":
                        normalized_text[idx_out] = normalized_text[idx_out] + " "
                idx_out += 1
                idx_in += 1
        except ValueError:
            # list.index found no further occurrence of the mark in the normalized text;
            # keep what has been aligned so far and move on to the next mark.
            continue

    normalized_text = "".join(normalized_text)
    return re.sub(r' +', ' ', normalized_text)

View File

@@ -0,0 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nemo_text_processing.text_normalization.en.taggers.tokenize_and_classify import ClassifyFst
from nemo_text_processing.text_normalization.en.verbalizers.verbalize import VerbalizeFst
from nemo_text_processing.text_normalization.en.verbalizers.verbalize_final import VerbalizeFinalFst

View File

@@ -0,0 +1,342 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import ArgumentParser
from typing import List
import regex as re
from nemo_text_processing.text_normalization.data_loader_utils import (
EOS_TYPE,
Instance,
load_files,
training_data_to_sentences,
)
"""
This file is for evaluation purposes.
filter_loaded_data() cleans data (list of instances) for text normalization. Filters and cleaners can be specified for each semiotic class individually.
For example, normalized text should only include characters and whitespace characters but no punctuation.
Cardinal unnormalized instances should contain at least one integer and all other characters are removed.
"""
class Filter:
    """
    Couples a semiotic class with a predicate and a transformation for its instances.

    Args:
        class_type: semiotic class used in dataset
        process_func: function to transform text
        filter_func: function to filter text
    """

    def __init__(self, class_type: str, process_func: object, filter_func: object):
        self.class_type = class_type
        self.process_func = process_func
        self.filter_func = filter_func

    def filter(self, instance: Instance) -> bool:
        """
        Applies the filter function to instances of this filter's class.

        Args:
            instance: instance to check

        Returns: True if the instance does not belong to class type, otherwise the result of the filter function
        """
        return instance.token_type != self.class_type or self.filter_func(instance)

    def process(self, instance: Instance) -> Instance:
        """
        Applies the process function to instances of this filter's class.

        Args:
            instance: instance to transform

        Returns: processed instance if instance belongs to expected class type or original instance
        """
        return self.process_func(instance) if instance.token_type == self.class_type else instance
def filter_cardinal_1(instance: Instance) -> bool:
    """Returns True if the written form of the cardinal contains at least one ASCII digit."""
    # re.search yields a match object or None; convert so callers get the annotated bool.
    return re.search(r"[0-9]", instance.un_normalized) is not None
def process_cardinal_1(instance: Instance) -> Instance:
    """Keeps only ASCII digits in the written form and only lowercase letters/spaces in the spoken form."""
    digits_only = re.sub(r"[^0-9]", "", instance.un_normalized)
    letters_only = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(token_type=instance.token_type, un_normalized=digits_only, normalized=letters_only)
def filter_ordinal_1(instance: Instance) -> bool:
    """Returns True if the written form ends in st/nd/rd/th (trailing whitespace allowed)."""
    # re.search yields a match object or None; convert so callers get the annotated bool.
    return re.search(r"(st|nd|rd|th)\s*$", instance.un_normalized) is not None
def process_ordinal_1(instance: Instance) -> Instance:
    """Removes commas and whitespace from the written form; keeps lowercase letters/spaces in the spoken form."""
    compact = re.sub(r"[,\s]", "", instance.un_normalized)
    letters_only = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(token_type=instance.token_type, un_normalized=compact, normalized=letters_only)
def filter_decimal_1(instance: Instance) -> bool:
    """Returns True if the written form of the decimal contains at least one ASCII digit."""
    # re.search yields a match object or None; convert so callers get the annotated bool.
    return re.search(r"[0-9]", instance.un_normalized) is not None
def process_decimal_1(instance: Instance) -> Instance:
    """Removes commas from the written form; keeps lowercase letters/spaces in the spoken form."""
    return Instance(
        token_type=instance.token_type,
        un_normalized=re.sub(r",", "", instance.un_normalized),
        normalized=re.sub(r"[^a-z ]", "", instance.normalized),
    )
def filter_measure_1(instance: Instance) -> bool:
    """Accepts every measure instance."""
    return True
def process_measure_1(instance: Instance) -> Instance:
    """
    Cleans a measure instance.
    Written form: commas and the literal "m2" are removed, and a space is inserted between a digit
    and a following character that is not a digit, dot or whitespace.
    Spoken form: reduced to lowercase letters and spaces, with the final "s" after "per ..." dropped.
    """
    written = instance.un_normalized
    for pattern, repl in ((r",", ""), (r"m2", ""), (r"(\d)([^\d.\s])", r"\1 \2")):
        written = re.sub(pattern, repl, written)

    spoken = instance.normalized
    for pattern, repl in ((r"[^a-z\s]", ""), (r"per ([a-z\s]*)s$", r"per \1"), (r"[^a-z ]", "")):
        spoken = re.sub(pattern, repl, spoken)

    return Instance(token_type=instance.token_type, un_normalized=written, normalized=spoken)
def filter_money_1(instance: Instance) -> bool:
    """Returns True if the written form of the money amount contains at least one ASCII digit."""
    # re.search yields a match object or None; convert so callers get the annotated bool.
    return re.search(r"[0-9]", instance.un_normalized) is not None
def process_money_1(instance: Instance) -> Instance:
    """
    Cleans a money instance.
    Written form: commas are removed, "a$" and "us$" become "$", a trailing "m" after a digit becomes
    " million" and a trailing "b"/"bn" after a digit becomes " billion".
    Spoken form: reduced to lowercase letters and spaces.
    """
    rewrites = (
        (r",", ""),
        (r"a\$", r"$"),
        (r"us\$", r"$"),
        (r"(\d)m\s*$", r"\1 million"),
        (r"(\d)bn?\s*$", r"\1 billion"),
    )
    written = instance.un_normalized
    for pattern, repl in rewrites:
        written = re.sub(pattern, repl, written)
    spoken = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(token_type=instance.token_type, un_normalized=written, normalized=spoken)
def filter_time_1(instance: Instance) -> bool:
    """Returns True if the written form of the time contains at least one ASCII digit."""
    # re.search yields a match object or None; convert so callers get the annotated bool.
    return re.search(r"[0-9]", instance.un_normalized) is not None
def process_time_1(instance: Instance) -> Instance:
    """
    Cleans a time instance.
    Written form: the space after a colon is removed and am/pm after a digit is rewritten as " a.m."/" p.m.".
    Spoken form: reduced to lowercase letters and spaces.
    """
    written = instance.un_normalized
    for pattern, repl in (
        (r": ", ":"),
        (r"(\d)\s?a\s?m\s?", r"\1 a.m."),
        (r"(\d)\s?p\s?m\s?", r"\1 p.m."),
    ):
        written = re.sub(pattern, repl, written)
    spoken = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(token_type=instance.token_type, un_normalized=written, normalized=spoken)
def filter_plain_1(instance: Instance) -> bool:
    """Accepts every plain instance."""
    return True
def process_plain_1(instance: Instance) -> Instance:
    """Returns a new Instance with the fields of the plain instance unchanged."""
    return Instance(
        token_type=instance.token_type,
        un_normalized=instance.un_normalized,
        normalized=instance.normalized,
    )
def filter_punct_1(instance: Instance) -> bool:
    """Accepts every punctuation instance."""
    return True
def process_punct_1(instance: Instance) -> Instance:
    """Returns a new Instance with the fields of the punctuation instance unchanged."""
    return Instance(
        token_type=instance.token_type,
        un_normalized=instance.un_normalized,
        normalized=instance.normalized,
    )
def filter_date_1(instance: Instance) -> bool:
    """Accepts every date instance."""
    return True
def process_date_1(instance: Instance) -> Instance:
    """Removes commas from the written form; keeps lowercase letters/spaces in the spoken form."""
    without_commas = re.sub(r",", "", instance.un_normalized)
    letters_only = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(token_type=instance.token_type, un_normalized=without_commas, normalized=letters_only)
def filter_letters_1(instance: Instance) -> bool:
    """Accepts every letters instance."""
    return True
def process_letters_1(instance: Instance) -> Instance:
    """Leaves the written form untouched; keeps only lowercase letters/spaces in the spoken form."""
    letters_only = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(
        token_type=instance.token_type, un_normalized=instance.un_normalized, normalized=letters_only
    )
def filter_verbatim_1(instance: Instance) -> bool:
    """Accepts every verbatim instance."""
    return True
def process_verbatim_1(instance: Instance) -> Instance:
    """Returns a new Instance with the fields of the verbatim instance unchanged."""
    return Instance(
        token_type=instance.token_type,
        un_normalized=instance.un_normalized,
        normalized=instance.normalized,
    )
def filter_digit_1(instance: Instance) -> bool:
    """Returns True if the written form of the digit token contains at least one ASCII digit."""
    # re.search yields a match object or None; convert so callers get the annotated bool.
    return re.search(r"[0-9]", instance.un_normalized) is not None
def process_digit_1(instance: Instance) -> Instance:
    """Leaves the written form untouched; keeps only lowercase letters/spaces in the spoken form."""
    letters_only = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(
        token_type=instance.token_type, un_normalized=instance.un_normalized, normalized=letters_only
    )
def filter_telephone_1(instance: Instance) -> bool:
    """Returns True if the written form of the telephone token contains at least one ASCII digit."""
    # re.search yields a match object or None; convert so callers get the annotated bool.
    return re.search(r"[0-9]", instance.un_normalized) is not None
def process_telephone_1(instance: Instance) -> Instance:
    """Leaves the written form untouched; keeps only lowercase letters/spaces in the spoken form."""
    letters_only = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(
        token_type=instance.token_type, un_normalized=instance.un_normalized, normalized=letters_only
    )
def filter_electronic_1(instance: Instance) -> bool:
    """Returns True if the written form of the electronic token contains at least one ASCII digit."""
    # re.search yields a match object or None; convert so callers get the annotated bool.
    return re.search(r"[0-9]", instance.un_normalized) is not None
def process_electronic_1(instance: Instance) -> Instance:
    """Leaves the written form untouched; keeps only lowercase letters/spaces in the spoken form."""
    letters_only = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(
        token_type=instance.token_type, un_normalized=instance.un_normalized, normalized=letters_only
    )
def filter_fraction_1(instance: Instance) -> bool:
    """Returns True if the written form of the fraction contains at least one ASCII digit."""
    # re.search yields a match object or None; convert so callers get the annotated bool.
    return re.search(r"[0-9]", instance.un_normalized) is not None
def process_fraction_1(instance: Instance) -> Instance:
    """Leaves the written form untouched; keeps only lowercase letters/spaces in the spoken form."""
    letters_only = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(
        token_type=instance.token_type, un_normalized=instance.un_normalized, normalized=letters_only
    )
def filter_address_1(instance: Instance) -> bool:
    """Accepts every address instance."""
    return True
def process_address_1(instance: Instance) -> Instance:
    """Leaves the written form untouched; keeps only lowercase letters/spaces in the spoken form."""
    letters_only = re.sub(r"[^a-z ]", "", instance.normalized)
    return Instance(
        token_type=instance.token_type, un_normalized=instance.un_normalized, normalized=letters_only
    )
# Registry of per-class filters, consulted in this order by filter_loaded_data().
# The EOS entry passes sentence boundaries through unchanged.
filters = [
    Filter(class_type="CARDINAL", process_func=process_cardinal_1, filter_func=filter_cardinal_1),
    Filter(class_type="ORDINAL", process_func=process_ordinal_1, filter_func=filter_ordinal_1),
    Filter(class_type="DECIMAL", process_func=process_decimal_1, filter_func=filter_decimal_1),
    Filter(class_type="MEASURE", process_func=process_measure_1, filter_func=filter_measure_1),
    Filter(class_type="MONEY", process_func=process_money_1, filter_func=filter_money_1),
    Filter(class_type="TIME", process_func=process_time_1, filter_func=filter_time_1),
    Filter(class_type="DATE", process_func=process_date_1, filter_func=filter_date_1),
    Filter(class_type="PLAIN", process_func=process_plain_1, filter_func=filter_plain_1),
    Filter(class_type="PUNCT", process_func=process_punct_1, filter_func=filter_punct_1),
    Filter(class_type="LETTERS", process_func=process_letters_1, filter_func=filter_letters_1),
    Filter(class_type="VERBATIM", process_func=process_verbatim_1, filter_func=filter_verbatim_1),
    Filter(class_type="DIGIT", process_func=process_digit_1, filter_func=filter_digit_1),
    Filter(class_type="TELEPHONE", process_func=process_telephone_1, filter_func=filter_telephone_1),
    Filter(class_type="ELECTRONIC", process_func=process_electronic_1, filter_func=filter_electronic_1),
    Filter(class_type="FRACTION", process_func=process_fraction_1, filter_func=filter_fraction_1),
    Filter(class_type="ADDRESS", process_func=process_address_1, filter_func=filter_address_1),
    Filter(class_type=EOS_TYPE, process_func=lambda x: x, filter_func=lambda x: True),
]
def filter_loaded_data(data: List[Instance], verbose: bool = False) -> List[Instance]:
    """
    Runs every instance through the registered filters of its class.
    An instance is kept (in processed form) only if at least one filter of its class accepted it.

    Args:
        data: list of instances
        verbose: if True prints every instance that is kept

    Returns: filtered and transformed list of instances
    """
    kept = []
    for item in data:
        accepted = False
        for rule in filters:
            if rule.class_type != item.token_type:
                continue
            if not rule.filter(item):
                continue
            item = rule.process(item)
            accepted = True
        if not accepted:
            continue
        if verbose:
            print(item)
        kept.append(item)
    return kept
def parse_args():
    """Parses the command line: --input (file path, with a default) and the --verbose switch."""
    cli = ArgumentParser()
    options = (
        ("--input", dict(help="input file path", type=str, default='./en_with_types/output-00001-of-00100')),
        ("--verbose", dict(help="print filtered instances", action='store_true')),
    )
    for flag, settings in options:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: load the file given on the command line, filter it,
    # and split the result into sentences (the split result is not used further here).
    args = parse_args()
    print("Loading training data: " + args.input)
    loaded = load_files([args.input])  # List of instances
    training_data_to_sentences(filter_loaded_data(loaded, args.verbose))

View File

@@ -0,0 +1,196 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright 2015 and onwards Google, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import string
from pathlib import Path
from typing import Dict
import pynini
from nemo_text_processing.text_normalization.en.utils import get_abs_path
from pynini import Far
from pynini.examples import plurals
from pynini.export import export
from pynini.lib import byte, pynutil, utf8
# --- Character classes -------------------------------------------------------
# Any single valid UTF-8 character, as defined by pynini.lib.utf8.
NEMO_CHAR = utf8.VALID_UTF8_CHAR
# A single digit, as defined by pynini.lib.byte.
NEMO_DIGIT = byte.DIGIT
# ASCII lowercase / uppercase letters and their unions.
NEMO_LOWER = pynini.union(*string.ascii_lowercase).optimize()
NEMO_UPPER = pynini.union(*string.ascii_uppercase).optimize()
NEMO_ALPHA = pynini.union(NEMO_LOWER, NEMO_UPPER).optimize()
NEMO_ALNUM = pynini.union(NEMO_DIGIT, NEMO_ALPHA).optimize()
# Any character of string.hexdigits (0-9, a-f, A-F).
NEMO_HEX = pynini.union(*string.hexdigits).optimize()
# Plain strings (not FSTs) for the two space characters used by the grammars.
NEMO_NON_BREAKING_SPACE = u"\u00A0"
NEMO_SPACE = " "
# One whitespace character: space, tab, newline, carriage return or non-breaking space.
NEMO_WHITE_SPACE = pynini.union(" ", "\t", "\n", "\r", u"\u00A0").optimize()
# Any valid character except whitespace / except the double quote.
NEMO_NOT_SPACE = pynini.difference(NEMO_CHAR, NEMO_WHITE_SPACE).optimize()
NEMO_NOT_QUOTE = pynini.difference(NEMO_CHAR, r'"').optimize()
# ASCII punctuation (string.punctuation); each mark is passed through pynini.escape first.
NEMO_PUNCT = pynini.union(*map(pynini.escape, string.punctuation)).optimize()
# Alphanumeric or punctuation character.
NEMO_GRAPH = pynini.union(NEMO_ALNUM, NEMO_PUNCT).optimize()
# Closure over NEMO_CHAR, i.e. any string of valid characters (including the empty one).
NEMO_SIGMA = pynini.closure(NEMO_CHAR)
# --- Small reusable transducers ----------------------------------------------
# Deletes a (possibly empty) run of whitespace.
delete_space = pynutil.delete(pynini.closure(NEMO_WHITE_SPACE))
# Deletes at most one whitespace character.
delete_zero_or_one_space = pynutil.delete(pynini.closure(NEMO_WHITE_SPACE, 0, 1))
# Inserts a single space.
insert_space = pynutil.insert(" ")
# Rewrites one or more whitespace characters as a single space.
delete_extra_space = pynini.cross(pynini.closure(NEMO_WHITE_SPACE, 1), " ")
# Deletes any number of ` preserve_order: true` / ` field_order: "..."` annotations.
# NOTE(review): NEMO_NOT_QUOTE matches a single character, so the quoted field_order value
# is kept in the output and can only be one character long here -- confirm this is intended.
delete_preserve_order = pynini.closure(
pynutil.delete(" preserve_order: true")
| (pynutil.delete(" field_order: \"") + NEMO_NOT_QUOTE + pynutil.delete("\""))
)
# --- Pluralization -----------------------------------------------------------
# Mapping loaded from data/suppletive.tsv (presumably irregular singular -> plural pairs; verify against the file).
suppletive = pynini.string_file(get_abs_path("data/suppletive.tsv"))
# _v = pynini.union("a", "e", "i", "o", "u")
# ASCII lowercase consonants.
_c = pynini.union(
"b", "c", "d", "f", "g", "h", "j", "k", "l", "m", "n", "p", "q", "r", "s", "t", "v", "w", "x", "y", "z"
)
# consonant + "y" -> "...ies"
_ies = NEMO_SIGMA + _c + pynini.cross("y", "ies")
# ending in s / sh / ch / x / z -> append "es"
_es = NEMO_SIGMA + pynini.union("s", "sh", "ch", "x", "z") + pynutil.insert("es")
# fallback: append "s"
_s = NEMO_SIGMA + pynutil.insert("s")
# Rules combined by priority: suppletive forms first, then -ies, then -es, then -s.
graph_plural = plurals._priority_union(
suppletive, plurals._priority_union(_ies, plurals._priority_union(_es, _s, NEMO_SIGMA), NEMO_SIGMA), NEMO_SIGMA
).optimize()
SINGULAR_TO_PLURAL = graph_plural
# Inverse direction, obtained by inverting the plural graph.
PLURAL_TO_SINGULAR = pynini.invert(graph_plural)
# --- Case conversion ---------------------------------------------------------
# Single ASCII letter: uppercase -> lowercase, and the inverse mapping.
TO_LOWER = pynini.union(*[pynini.cross(x, y) for x, y in zip(string.ascii_uppercase, string.ascii_lowercase)])
TO_UPPER = pynini.invert(TO_LOWER)
# Small negative / positive weight constants (their use is outside this section).
MIN_NEG_WEIGHT = -0.0001
MIN_POS_WEIGHT = 0.0001
def generator_main(file_name: str, graphs: Dict[str, 'pynini.FstLike']):
    """
    Writes the given graphs into one OpenFst finite state archive (FAR) file.
    Every graph is optimized before it is stored under its rule name.

    Args:
        file_name: exported file name
        graphs: Mapping of a rule name and Pynini WFST graph to be exported
    """
    far_writer = export.Exporter(file_name)
    for rule_name in graphs:
        far_writer[rule_name] = graphs[rule_name].optimize()
    far_writer.close()
    print(f'Created {file_name}')
def get_plurals(fst):
    """
    Composes the singular-to-plural graph with the given fst.

    Args:
        fst: Fst

    Returns plurals to given singular forms
    """
    pluralized = SINGULAR_TO_PLURAL @ fst
    return pluralized
def get_singulars(fst):
    """
    Composes the plural-to-singular graph with the given fst.

    Args:
        fst: Fst

    Returns singulars to given plural forms
    """
    singularized = PLURAL_TO_SINGULAR @ fst
    return singularized
def convert_space(fst) -> 'pynini.FstLike':
    """
    Replaces every breaking space in the output of `fst` with a non-breaking space.
    Used only in tagger grammars for transducing token values within quotes, e.g. name: "hello kitty"
    This is making transducer significantly slower, so only use when there could be potential spaces
    within quotes, otherwise leave it.

    Args:
        fst: input fst

    Returns output fst where breaking spaces are converted to non breaking spaces
    """
    space_to_nbsp = pynini.cross(NEMO_SPACE, NEMO_NON_BREAKING_SPACE)
    # context-free rewrite: applies everywhere in the string
    rewrite_everywhere = pynini.cdrewrite(space_to_nbsp, "", "", NEMO_SIGMA)
    return fst @ rewrite_everywhere
class GraphFst:
    """
    Base class for all grammar fsts.

    On construction the class looks for a pre-compiled FAR file at
    <directory of this module>/grammars/<kind>/<name>.far and, if it exists, loads the fst from it.

    Args:
        name: name of grammar class
        kind: either 'classify' or 'verbalize'
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, name: str, kind: str, deterministic: bool = True):
        self.name = name
        # Keep the grammar kind that was passed in (the builtin `str` type used to be stored here by mistake).
        self.kind = kind
        self._fst = None
        self.deterministic = deterministic

        self.far_path = Path(os.path.dirname(__file__)) / 'grammars' / kind / (name + '.far')
        if self.far_exist():
            self._fst = Far(self.far_path, mode="r", arc_type="standard", far_type="default").get_fst()

    def far_exist(self) -> bool:
        """
        Returns true if FAR can be loaded, i.e. the FAR file exists on disk
        """
        return self.far_path.exists()

    @property
    def fst(self) -> 'pynini.FstLike':
        """The grammar's fst, or None if none was loaded or assigned yet."""
        return self._fst

    @fst.setter
    def fst(self, fst):
        self._fst = fst

    def add_tokens(self, fst) -> 'pynini.FstLike':
        """
        Wraps class name around to given fst, i.e. inserts `<name> { ` before and ` }` after its output

        Args:
            fst: input fst

        Returns:
            Fst: fst
        """
        return pynutil.insert(f"{self.name} {{ ") + fst + pynutil.insert(" }")

    def delete_tokens(self, fst) -> 'pynini.FstLike':
        """
        Deletes class name wrap around output of given fst
        and converts non-breaking spaces back to regular spaces

        Args:
            fst: input fst

        Returns:
            Fst: fst
        """
        res = (
            pynutil.delete(f"{self.name}")
            + delete_space
            + pynutil.delete("{")
            + delete_space
            + fst
            + delete_space
            + pynutil.delete("}")
        )
        # Undo the space -> non-breaking space conversion; uses the module constants for both characters.
        return res @ pynini.cdrewrite(pynini.cross(NEMO_NON_BREAKING_SPACE, NEMO_SPACE), "", "", NEMO_SIGMA)

View File

@@ -0,0 +1,13 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,50 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_UPPER, GraphFst, insert_space
from pynini.lib import pynutil
class AbbreviationFst(GraphFst):
    """
    Finite state transducer for classifying abbreviations written in upper-case letters,
    with or without dots between the letters,
        e.g. "ABC" -> tokens { abbreviation { value: "A B C" } }

    Args:
        whitelist: whitelist FST
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, whitelist: 'pynini.FstLike', deterministic: bool = True):
        super().__init__(name="abbreviation", kind="classify", deterministic=deterministic)

        period = pynini.accep(".")
        # A.B.C. -> A. B. C.
        dotted_spaced = NEMO_UPPER + period + pynini.closure(insert_space + NEMO_UPPER + period, 1)
        # A.B.C. -> A.B.C.
        dotted_as_is = NEMO_UPPER + period + pynini.closure(NEMO_UPPER + period, 1)
        # ABC -> A B C
        letters_spaced = NEMO_UPPER + pynini.closure(insert_space + NEMO_UPPER, 1)
        candidates = dotted_spaced | dotted_as_is | letters_spaced

        # exclude words that are included in the whitelist
        accepted_inputs = pynini.difference(
            pynini.project(candidates, "input"), pynini.project(whitelist.graph, "input")
        )
        filtered = pynini.compose(accepted_inputs, candidates)

        tagged = pynutil.insert("value: \"") + filtered.optimize() + pynutil.insert("\"")
        self.fst = self.add_tokens(tagged).optimize()

View File

@@ -0,0 +1,138 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
NEMO_DIGIT,
NEMO_NOT_QUOTE,
NEMO_SIGMA,
GraphFst,
insert_space,
)
from nemo_text_processing.text_normalization.en.taggers.date import get_four_digit_year_graph
from nemo_text_processing.text_normalization.en.utils import get_abs_path
from pynini.examples import plurals
from pynini.lib import pynutil
class CardinalFst(GraphFst):
    """
    Finite state transducer for classifying cardinals, e.g.
        -23 -> cardinal { negative: "true" integer: "twenty three" }

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
        lm: stored on the instance; when True, add_optional_and() returns its input graph
            unchanged (no optional "and" variants are added)
    """

    def __init__(self, deterministic: bool = True, lm: bool = False):
        super().__init__(name="cardinal", kind="classify", deterministic=deterministic)

        self.lm = lm
        self.deterministic = deterministic
        # TODO replace to have "oh" as a default for "0"
        graph = pynini.Far(get_abs_path("data/number/cardinal_number_name.far")).get_fst()
        # input restricted to two or three digits, or a single non-zero digit,
        # before the number-name grammar is applied
        self.graph_hundred_component_at_least_one_none_zero_digit = (
            pynini.closure(NEMO_DIGIT, 2, 3) | pynini.difference(NEMO_DIGIT, pynini.accep("0"))
        ) @ graph

        graph_digit = pynini.string_file(get_abs_path("data/number/digit.tsv"))
        graph_zero = pynini.string_file(get_abs_path("data/number/zero.tsv"))

        # digit-by-digit reading: one or more single digits, separated by inserted spaces
        single_digits_graph = pynini.invert(graph_digit | graph_zero)
        self.single_digits_graph = single_digits_graph + pynini.closure(insert_space + single_digits_graph)

        if not deterministic:
            # for a single token allow only the same normalization
            # "007" -> {"oh oh seven", "zero zero seven"} not {"oh zero seven"}
            single_digits_graph_zero = pynini.invert(graph_digit | graph_zero)
            single_digits_graph_oh = pynini.invert(graph_digit) | pynini.cross("0", "oh")

            self.single_digits_graph = single_digits_graph_zero + pynini.closure(
                insert_space + single_digits_graph_zero
            )
            self.single_digits_graph |= single_digits_graph_oh + pynini.closure(insert_space + single_digits_graph_oh)

            # digit-by-digit reading of comma-grouped numbers: a first group read through
            # self.single_digits_graph (repeated 1-3 times), then one or more ",ddd" groups
            # with the comma deleted
            single_digits_graph_with_commas = pynini.closure(
                self.single_digits_graph + insert_space, 1, 3
            ) + pynini.closure(
                pynutil.delete(",")
                + single_digits_graph
                + insert_space
                + single_digits_graph
                + insert_space
                + single_digits_graph,
                1,
            )

        # optional leading "-" becomes the field: negative: "true"
        optional_minus_graph = pynini.closure(pynutil.insert("negative: ") + pynini.cross("-", "\"true\" "), 0, 1)

        # 1-3 leading digits followed by any number of 3-digit groups,
        # either all comma-separated (commas deleted) or all unseparated
        graph = (
            pynini.closure(NEMO_DIGIT, 1, 3)
            + (pynini.closure(pynutil.delete(",") + NEMO_DIGIT ** 3) | pynini.closure(NEMO_DIGIT ** 3))
        ) @ graph

        self.graph = graph
        self.graph_with_and = self.add_optional_and(graph)

        if deterministic:
            # inputs of five or more digits are read digit by digit; priority union gives
            # this reading precedence over graph_with_and for the inputs it accepts
            long_numbers = pynini.compose(NEMO_DIGIT ** (5, ...), self.single_digits_graph).optimize()
            final_graph = plurals._priority_union(long_numbers, self.graph_with_and, NEMO_SIGMA).optimize()
            # anything starting with "0" is read digit by digit
            cardinal_with_leading_zeros = pynini.compose(
                pynini.accep("0") + pynini.closure(NEMO_DIGIT), self.single_digits_graph
            )
            final_graph |= cardinal_with_leading_zeros
        else:
            # leading zeros are read digit by digit, the remainder as a regular cardinal
            leading_zeros = pynini.compose(pynini.closure(pynini.accep("0"), 1), self.single_digits_graph)
            cardinal_with_leading_zeros = (
                leading_zeros + pynutil.insert(" ") + pynini.compose(pynini.closure(NEMO_DIGIT), self.graph_with_and)
            )

            # add small weight to non-default graphs to make sure the deterministic option is listed first
            final_graph = (
                self.graph_with_and
                | pynutil.add_weight(self.single_digits_graph, 0.0001)
                | get_four_digit_year_graph()  # allows e.g. 4567 be pronounced as forty five sixty seven
                | pynutil.add_weight(single_digits_graph_with_commas, 0.0001)
                | cardinal_with_leading_zeros
            )

        final_graph = optional_minus_graph + pynutil.insert("integer: \"") + final_graph + pynutil.insert("\"")
        final_graph = self.add_tokens(final_graph)
        self.fst = final_graph.optimize()

    def add_optional_and(self, graph):
        """
        Returns `graph` extended with alternative readings, unless self.lm is set
        (in which case `graph` is returned as is).

        Alternatives added on top of the (slightly up-weighted) input graph:
            - "hundred " -> "hundred and " when no "thousand"/"million" follows
            - "thousand " -> "thousand and " when no "hundred" follows
            - for inputs of three non-zero digits, the reading with " hundred" removed

        Args:
            graph: cardinal graph mapping digit strings to their verbalization
        """
        graph_with_and = graph

        if not self.lm:
            graph_with_and = pynutil.add_weight(graph, 0.00001)
            not_quote = pynini.closure(NEMO_NOT_QUOTE)
            # strings that contain neither "thousand" nor "million"
            no_thousand_million = pynini.difference(
                not_quote, not_quote + pynini.union("thousand", "million") + not_quote
            ).optimize()
            integer = (
                not_quote + pynutil.add_weight(pynini.cross("hundred ", "hundred and ") + no_thousand_million, -0.0001)
            ).optimize()

            # strings that do not contain "hundred"
            no_hundred = pynini.difference(NEMO_SIGMA, not_quote + pynini.accep("hundred") + not_quote).optimize()
            integer |= (
                not_quote + pynutil.add_weight(pynini.cross("thousand ", "thousand and ") + no_hundred, -0.0001)
            ).optimize()

            # three non-zero digits: verbalize with `graph`, then drop " hundred" from the output
            optional_hundred = pynini.compose((NEMO_DIGIT - "0") ** 3, graph).optimize()
            optional_hundred = pynini.compose(optional_hundred, NEMO_SIGMA + pynini.cross(" hundred", "") + NEMO_SIGMA)
            graph_with_and |= pynini.compose(graph, integer).optimize()
            graph_with_and |= optional_hundred
        return graph_with_and

View File

@@ -0,0 +1,370 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
NEMO_CHAR,
NEMO_DIGIT,
NEMO_LOWER,
NEMO_SIGMA,
NEMO_NOT_QUOTE,
TO_LOWER,
GraphFst,
delete_extra_space,
delete_space,
insert_space,
)
from nemo_text_processing.text_normalization.en.utils import (
augment_labels_with_punct_at_end,
get_abs_path,
load_labels,
)
from pynini.examples import plurals
from pynini.lib import pynutil
# Module-level graphs shared by the year helpers below. Each TSV mapping is inverted,
# so these transducers read digits and emit words (see the examples in get_ties_graph).
graph_teen = pynini.invert(pynini.string_file(get_abs_path("data/number/teen.tsv"))).optimize()
graph_digit = pynini.invert(pynini.string_file(get_abs_path("data/number/digit.tsv"))).optimize()
ties_graph = pynini.invert(pynini.string_file(get_abs_path("data/number/ty.tsv"))).optimize()
# Year suffix labels, extended with the variants returned by augment_labels_with_punct_at_end,
# then compiled into a single string map (the name is rebound from a list to an FST).
year_suffix = load_labels(get_abs_path("data/date/year_suffix.tsv"))
year_suffix.extend(augment_labels_with_punct_at_end(year_suffix))
year_suffix = pynini.string_map(year_suffix).optimize()
def get_ties_graph(deterministic: bool = True):
    """
    Returns two digit transducer, e.g.
    03 -> o three
    12 -> twelve
    20 -> twenty

    Args:
        deterministic: if True a leading "0" is only read as "o",
            otherwise both "o" and "zero" are allowed
    """
    round_tens = ties_graph + pynutil.delete("0")
    tens_with_unit = ties_graph + insert_space + graph_digit

    leading_zero = pynini.cross("0", "o")
    if not deterministic:
        leading_zero = leading_zero | pynini.cross("0", "zero")
    zero_with_unit = leading_zero + insert_space + graph_digit

    two_digits = graph_teen | round_tens | tens_with_unit | zero_with_unit
    return two_digits.optimize()
def get_four_digit_year_graph(deterministic: bool = True):
    """
    Returns a four digit transducer which is combination of ties/teen or digits
    (using hundred instead of thousand format), e.g.
    1219 -> twelve nineteen
    3900 -> thirty nine hundred

    Args:
        deterministic: forwarded to get_ties_graph(); if True the "thousand" readings take
            priority over the other readings, otherwise they are added as alternatives
    """
    two_digits = get_ties_graph(deterministic)
    two_digit_pair = two_digits + insert_space + two_digits
    hundreds = (graph_teen | two_digits) + insert_space + pynini.cross("00", "hundred")

    # readings for inputs carrying a trailing "s": the "s" is deleted from the input and
    # the end of the output is rewritten instead ("y" -> "ies", or an "s" is appended)
    stem_before_0s = two_digit_pair | (graph_teen + insert_space + (ties_graph | pynini.cross("1", "ten")))
    with_s = (stem_before_0s + pynutil.delete("0s")) | (hundreds + pynutil.delete("s"))
    pluralize_ending = pynini.cdrewrite(pynini.cross("y", "ies") | pynutil.insert("s"), "", "[EOS]", NEMO_SIGMA)
    with_s = with_s @ pluralize_ending

    # d00d / d000 -> "<digit> thousand [<digit>]"
    thousand_with_unit = (
        graph_digit
        + insert_space
        + pynini.cross("00", "thousand")
        + (pynutil.delete("0") | insert_space + graph_digit)
    )
    # d000s / d000 s -> "<digit> thousands"
    thousand_with_s = (
        graph_digit
        + insert_space
        + pynini.cross("000", "thousand")
        + pynini.closure(pynutil.delete(" "), 0, 1)
        + pynini.accep("s")
    )
    thousand_graph = thousand_with_unit | thousand_with_s

    year = two_digit_pair | hundreds | with_s
    if deterministic:
        year = plurals._priority_union(thousand_graph, year, NEMO_SIGMA)
    else:
        year = year | thousand_graph
    return year.optimize()
def _get_two_digit_year_with_s_graph():
    """
    Returns transducer for decades written with a trailing "s" and an optional
    leading apostrophe, e.g. '70s -> seventies
    """
    optional_apostrophe = pynini.closure(pynutil.delete("'"), 0, 1)
    decade = ties_graph + pynutil.delete("0s")
    # the final "y" of the verbalized tens word becomes "ies"
    y_to_ies = pynini.cdrewrite(pynini.cross("y", "ies"), "", "[EOS]", NEMO_SIGMA)
    return (optional_apostrophe + pynini.compose(decade, y_to_ies)).optimize()
def _get_year_graph(cardinal_graph, deterministic: bool = True):
    """
    Transducer for year, only from 1000 - 2999 e.g.
    1290 -> twelve ninety
    2000 - 2009 will be verbalized as two thousand.

    Transducer for 3 digit year, e.g. 123-> one twenty three

    Transducer for year with suffix
    123 A.D., 4200 B.C

    Args:
        cardinal_graph: cardinal transducer used to verbalize the parts of a three digit year
        deterministic: forwarded to get_four_digit_year_graph() for the plain four digit years
    """
    # four digits starting with "1" or "2", optionally followed by "s" or " s" (the space is dropped)
    optional_s = pynini.closure(pynini.cross(" s", "s") | "s", 0, 1)
    four_digit_input = pynini.union("1", "2") + (NEMO_DIGIT ** 3) + optional_s
    years = four_digit_input @ get_four_digit_year_graph(deterministic)
    years = years | _get_two_digit_year_with_s_graph()

    three_digit_year = (NEMO_DIGIT @ cardinal_graph) + insert_space + (NEMO_DIGIT ** 2) @ cardinal_graph
    # the suffixed form always uses the deterministic four digit graph, whatever `deterministic` is
    year_before_suffix = get_four_digit_year_graph(deterministic=True) | three_digit_year
    year_with_suffix = year_before_suffix + delete_space + insert_space + year_suffix

    return (years | year_with_suffix).optimize()
def _get_two_digit_year(cardinal_graph, single_digits_graph):
    """
    Returns transducer for exactly two digits, verbalized with `cardinal_graph` where it
    applies and with `single_digits_graph` otherwise (priority union).

    Args:
        cardinal_graph: preferred cardinal transducer
        single_digits_graph: fallback digit-by-digit transducer
    """
    verbalizer = plurals._priority_union(cardinal_graph, single_digits_graph, NEMO_SIGMA)
    return NEMO_DIGIT ** (2) @ verbalizer
class DateFst(GraphFst):
    """
    Finite state transducer for classifying date, e.g.
        jan. 5, 2012 -> date { month: "january" day: "five" year: "twenty twelve" preserve_order: true }
        jan. 5 -> date { month: "january" day: "five" preserve_order: true }
        5 january 2012 -> date { day: "five" month: "january" year: "twenty twelve" preserve_order: true }
        2012-01-05 -> date { year: "twenty twelve" month: "january" day: "five" }
        2012.01.05 -> date { year: "twenty twelve" month: "january" day: "five" }
        2012/01/05 -> date { year: "twenty twelve" month: "january" day: "five" }
        2012 -> date { year: "twenty twelve" }

    Args:
        cardinal: CardinalFst
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
        lm: if True, the additional options built for deterministic=False (optional
            preserve_order, "month<sep>day" dates and the re-ordered variants) are built as well
    """

    def __init__(self, cardinal: GraphFst, deterministic: bool, lm: bool = False):
        super().__init__(name="date", kind="classify", deterministic=deterministic)

        # january
        month_graph = pynini.string_file(get_abs_path("data/date/month_name.tsv")).optimize()
        # January, JANUARY
        month_graph |= pynini.compose(TO_LOWER + pynini.closure(NEMO_CHAR), month_graph) | pynini.compose(
            TO_LOWER ** (2, ...), month_graph
        )

        # jan
        month_abbr_graph = pynini.string_file(get_abs_path("data/date/month_abbr.tsv")).optimize()
        # jan, Jan, JAN
        # a trailing "." after the abbreviation is optional and deleted
        month_abbr_graph = (
            month_abbr_graph
            | pynini.compose(TO_LOWER + pynini.closure(NEMO_LOWER, 1), month_abbr_graph).optimize()
            | pynini.compose(TO_LOWER ** (2, ...), month_abbr_graph).optimize()
        ) + pynini.closure(pynutil.delete("."), 0, 1)
        month_graph |= month_abbr_graph.optimize()

        month_numbers_labels = pynini.string_file(get_abs_path("data/date/month_number.tsv")).optimize()
        cardinal_graph = cardinal.graph_hundred_component_at_least_one_none_zero_digit

        year_graph = _get_year_graph(cardinal_graph=cardinal_graph, deterministic=deterministic)

        # NOTE: disabled code kept from the original
        # three_digit_year = (NEMO_DIGIT @ cardinal_graph) + insert_space + (NEMO_DIGIT ** 2) @ cardinal_graph
        # year_graph |= three_digit_year

        # from here on month_graph / month_numbers_graph emit the full `month: "..."` field
        month_graph = pynutil.insert("month: \"") + month_graph + pynutil.insert("\"")
        month_numbers_graph = pynutil.insert("month: \"") + month_numbers_labels + pynutil.insert("\"")

        # ordinal endings in lower and upper case; they are deleted from the day below
        endings = ["rd", "th", "st", "nd"]
        endings += [x.upper() for x in endings]
        endings = pynini.union(*endings)

        # day field: optional "the " (deleted), a one digit number or a two digit number
        # 10-31, and an optional ordinal ending (deleted).
        # `@` binds tighter than `+`, so only the digits/ending group is composed with cardinal_graph.
        day_graph = (
            pynutil.insert("day: \"")
            + pynini.closure(pynutil.delete("the "), 0, 1)
            + (
                ((pynini.union("1", "2") + NEMO_DIGIT) | NEMO_DIGIT | (pynini.accep("3") + pynini.union("0", "1")))
                + pynini.closure(pynutil.delete(endings), 0, 1)
            )
            @ cardinal_graph
            + pynutil.insert("\"")
        )

        two_digit_year = _get_two_digit_year(
            cardinal_graph=cardinal_graph, single_digits_graph=cardinal.single_digits_graph
        )
        two_digit_year = pynutil.insert("year: \"") + two_digit_year + pynutil.insert("\"")

        # NOTE: disabled code kept from the original
        # if lm:
        # two_digit_year = pynini.compose(pynini.difference(NEMO_DIGIT, "0") + NEMO_DIGIT ** (3), two_digit_year)
        # year_graph = pynini.compose(pynini.difference(NEMO_DIGIT, "0") + NEMO_DIGIT ** (2), year_graph)
        # year_graph |= pynini.compose(pynini.difference(NEMO_DIGIT, "0") + NEMO_DIGIT ** (4, ...), year_graph)

        # year that follows other date parts: either separated by a single space (deleted),
        # or by "," plus an optional space, which are passed through inside the year value
        graph_year = pynutil.insert(" year: \"") + pynutil.delete(" ") + year_graph + pynutil.insert("\"")
        graph_year |= (
            pynutil.insert(" year: \"")
            + pynini.accep(",")
            + pynini.closure(pynini.accep(" "), 0, 1)
            + year_graph
            + pynutil.insert("\"")
        )
        optional_graph_year = pynini.closure(graph_year, 0, 1)

        # from here on year_graph emits the full `year: "..."` field (stand-alone year)
        year_graph = pynutil.insert("year: \"") + year_graph + pynutil.insert("\"")

        # month name first: "<month> <day>", "<month> <year>" or "<month> <day> <year>"
        graph_mdy = month_graph + (
            (delete_extra_space + day_graph)
            | (pynini.accep(" ") + day_graph)
            | graph_year
            | (delete_extra_space + day_graph + graph_year)
        )

        # month name first with "-" separators; the second "-" is turned into a space
        # before the remainder is matched against graph_year
        graph_mdy |= (
            month_graph
            + pynini.cross("-", " ")
            + day_graph
            + pynini.closure(((pynini.cross("-", " ") + NEMO_SIGMA) @ graph_year), 0, 1)
        )

        # numeric month first: MM<sep>DD<sep>YYYY or MM<sep>DD<sep>YY, same separator twice;
        # one leading "0" of the day is dropped
        for x in ["-", "/", "."]:
            delete_sep = pynutil.delete(x)
            graph_mdy |= (
                month_numbers_graph
                + delete_sep
                + insert_space
                + pynini.closure(pynutil.delete("0"), 0, 1)
                + day_graph
                + delete_sep
                + insert_space
                + (year_graph | two_digit_year)
            )

        # day first with a month name: "<day> <month> [<year>]"
        graph_dmy = day_graph + delete_extra_space + month_graph + optional_graph_year
        # two digit days whose digits are not also accepted as a month number
        day_ex_month = (NEMO_DIGIT ** 2 - pynini.project(month_numbers_graph, "input")) @ day_graph
        for x in ["-", "/", "."]:
            delete_sep = pynutil.delete(x)
            graph_dmy |= (
                day_ex_month
                + delete_sep
                + insert_space
                + month_numbers_graph
                + delete_sep
                + insert_space
                + (year_graph | two_digit_year)
            )

        # year first: YYYY<sep>MM<sep>DD or YY<sep>MM<sep>DD.
        # The union starts from the empty-string acceptor, so graph_ymd also accepts "".
        graph_ymd = pynini.accep("")
        for x in ["-", "/", "."]:
            delete_sep = pynutil.delete(x)
            graph_ymd |= (
                (year_graph | two_digit_year)
                + delete_sep
                + insert_space
                + month_numbers_graph
                + delete_sep
                + insert_space
                + pynini.closure(pynutil.delete("0"), 0, 1)
                + day_graph
            )

        final_graph = graph_mdy | graph_dmy

        if not deterministic or lm:
            # preserve_order becomes optional, and "MM-DD" / "MM/DD" without a year is accepted
            final_graph += pynini.closure(pynutil.insert(" preserve_order: true"), 0, 1)
            m_sep_d = (
                month_numbers_graph
                + pynutil.delete(pynini.union("-", "/"))
                + insert_space
                + pynini.closure(pynutil.delete("0"), 0, 1)
                + day_graph
            )
            final_graph |= m_sep_d
        else:
            final_graph += pynutil.insert(" preserve_order: true")

        # year-first dates and stand-alone years are added without preserve_order
        final_graph |= graph_ymd | year_graph

        if not deterministic or lm:
            # Additional variants with the fields re-ordered. For every (month, day) label pair
            # a rewrite of the tagged output is composed with the matching date graph and the
            # results are accumulated by union (None marks "nothing accumulated yet").
            ymd_to_mdy_graph = None
            ymd_to_dmy_graph = None
            mdy_to_dmy_graph = None
            md_to_dm_graph = None

            for month in [x[0] for x in load_labels(get_abs_path("data/date/month_name.tsv"))]:
                for day in [x[0] for x in load_labels(get_abs_path("data/date/day.tsv"))]:
                    ymd_to_mdy_curr = (
                        pynutil.insert("month: \"" + month + "\" day: \"" + day + "\" ")
                        + pynini.accep('year:')
                        + NEMO_SIGMA
                        + pynutil.delete(" month: \"" + month + "\" day: \"" + day + "\"")
                    )

                    # YY-MM-DD -> MM-DD-YY
                    ymd_to_mdy_curr = pynini.compose(graph_ymd, ymd_to_mdy_curr)
                    ymd_to_mdy_graph = (
                        ymd_to_mdy_curr
                        if ymd_to_mdy_graph is None
                        else pynini.union(ymd_to_mdy_curr, ymd_to_mdy_graph)
                    )

                    ymd_to_dmy_curr = (
                        pynutil.insert("day: \"" + day + "\" month: \"" + month + "\" ")
                        + pynini.accep('year:')
                        + NEMO_SIGMA
                        + pynutil.delete(" month: \"" + month + "\" day: \"" + day + "\"")
                    )

                    # YY-MM-DD -> DD-MM-YY
                    ymd_to_dmy_curr = pynini.compose(graph_ymd, ymd_to_dmy_curr).optimize()
                    ymd_to_dmy_graph = (
                        ymd_to_dmy_curr
                        if ymd_to_dmy_graph is None
                        else pynini.union(ymd_to_dmy_curr, ymd_to_dmy_graph)
                    )

                    mdy_to_dmy_curr = (
                        pynutil.insert("day: \"" + day + "\" month: \"" + month + "\" ")
                        + pynutil.delete("month: \"" + month + "\" day: \"" + day + "\" ")
                        + pynini.accep('year:')
                        + NEMO_SIGMA
                    ).optimize()
                    # MM-DD-YY -> verbalize as MM-DD-YY (February fourth 1991) or DD-MM-YY (the fourth of February 1991)
                    mdy_to_dmy_curr = pynini.compose(graph_mdy, mdy_to_dmy_curr).optimize()
                    mdy_to_dmy_graph = (
                        mdy_to_dmy_curr
                        if mdy_to_dmy_graph is None
                        else pynini.union(mdy_to_dmy_curr, mdy_to_dmy_graph).optimize()
                    ).optimize()

                    # MM-DD (no year) -> DD-MM
                    md_to_dm_curr = pynutil.insert("day: \"" + day + "\" month: \"" + month + "\"") + pynutil.delete(
                        "month: \"" + month + "\" day: \"" + day + "\""
                    )
                    md_to_dm_curr = pynini.compose(m_sep_d, md_to_dm_curr).optimize()
                    md_to_dm_graph = (
                        md_to_dm_curr
                        if md_to_dm_graph is None
                        else pynini.union(md_to_dm_curr, md_to_dm_graph).optimize()
                    ).optimize()

            final_graph |= mdy_to_dmy_graph | md_to_dm_graph | ymd_to_mdy_graph | ymd_to_dmy_graph

        final_graph = self.add_tokens(final_graph)
        self.fst = final_graph.optimize()

View File

@@ -0,0 +1,129 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_SIGMA, TO_UPPER, GraphFst, get_abs_path
from pynini.lib import pynutil
# Deletes exactly one space character.
delete_space = pynutil.delete(" ")
# Quantity words and their abbreviations, loaded from TSV mappings.
quantities = pynini.string_file(get_abs_path("data/number/thousand.tsv"))
quantities_abbr = pynini.string_file(get_abs_path("data/number/quantity_abbr.tsv"))
# Also accept abbreviations whose input has been passed through TO_UPPER.
quantities_abbr |= TO_UPPER @ quantities_abbr
def get_quantity(
    decimal: 'pynini.FstLike', cardinal_up_to_hundred: 'pynini.FstLike', include_abbr: bool
) -> 'pynini.FstLike':
    """
    Returns FST that transforms either a cardinal or decimal followed by a quantity into a numeral,
    e.g. 1 million -> integer_part: "one" quantity: "million"
    e.g. 1.5 million -> integer_part: "one" fractional_part: "five" quantity: "million"

    Args:
        decimal: decimal FST
        cardinal_up_to_hundred: cardinal FST
        include_abbr: if True, abbreviated quantities (quantities_abbr) are accepted as well
    """
    thousand_forms = pynini.union("k", "K", "thousand")
    optional_space = pynini.closure(pynutil.delete(" "), 0, 1)

    # after a bare cardinal every quantity is allowed except the "thousand" forms
    quantity_wo_thousand = pynini.project(quantities, "input") - thousand_forms
    if include_abbr:
        quantity_wo_thousand = quantity_wo_thousand | (pynini.project(quantities_abbr, "input") - thousand_forms)

    cardinal_with_quantity = (
        pynutil.insert("integer_part: \"")
        + cardinal_up_to_hundred
        + pynutil.insert("\"")
        + optional_space
        + pynutil.insert(" quantity: \"")
        + (quantity_wo_thousand @ (quantities | quantities_abbr))
        + pynutil.insert("\"")
    )

    # after a decimal all quantities are allowed, including the "thousand" forms
    quantity = (quantities | quantities_abbr) if include_abbr else quantities
    decimal_with_quantity = (
        decimal
        + optional_space
        + pynutil.insert("quantity: \"")
        + quantity
        + pynutil.insert("\"")
    )

    return cardinal_with_quantity | decimal_with_quantity
class DecimalFst(GraphFst):
    """
    Finite state transducer for classifying decimal, e.g.
        -12.5006 billion -> decimal { negative: "true" integer_part: "twelve" fractional_part: "five o o six" quantity: "billion" }
        1 billion -> decimal { integer_part: "one" quantity: "billion" }

    Args:
        cardinal: CardinalFst
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, cardinal: GraphFst, deterministic: bool):
        super().__init__(name="decimal", kind="classify", deterministic=deterministic)

        cardinal_graph = cardinal.graph_with_and
        cardinal_graph_hundred_component_at_least_one_none_zero_digit = (
            cardinal.graph_hundred_component_at_least_one_none_zero_digit
        )

        # fractional digits are read one by one.
        # NOTE: optimize() is called directly on cardinal.single_digits_graph, so self.graph
        # refers to the cardinal's own (now optimized) graph object.
        self.graph = cardinal.single_digits_graph.optimize()

        if not deterministic:
            # additionally allow the fractional part to be read as a regular cardinal
            self.graph = self.graph | cardinal_graph

        point = pynutil.delete(".")
        # optional leading "-" becomes the field: negative: "true"
        optional_graph_negative = pynini.closure(pynutil.insert("negative: ") + pynini.cross("-", "\"true\" "), 0, 1)

        self.graph_fractional = pynutil.insert("fractional_part: \"") + self.graph + pynutil.insert("\"")
        self.graph_integer = pynutil.insert("integer_part: \"") + cardinal_graph + pynutil.insert("\"")
        # [integer part] "." fractional part; the integer part is optional (e.g. ".5")
        final_graph_wo_sign = (
            pynini.closure(self.graph_integer + pynutil.insert(" "), 0, 1)
            + point
            + pynutil.insert(" ")
            + self.graph_fractional
        )

        quantity_w_abbr = get_quantity(
            final_graph_wo_sign, cardinal_graph_hundred_component_at_least_one_none_zero_digit, include_abbr=True
        )
        quantity_wo_abbr = get_quantity(
            final_graph_wo_sign, cardinal_graph_hundred_component_at_least_one_none_zero_digit, include_abbr=False
        )
        # unsigned graphs, with and without abbreviated quantities
        self.final_graph_wo_negative_w_abbr = final_graph_wo_sign | quantity_w_abbr
        self.final_graph_wo_negative = final_graph_wo_sign | quantity_wo_abbr

        # reduce options for non_deterministic and allow either "oh" or "zero", but not combination
        if not deterministic:
            # strings that do not contain both "oh" and "zero" (in either order)
            no_oh_zero = pynini.difference(
                NEMO_SIGMA,
                (NEMO_SIGMA + "oh" + NEMO_SIGMA + "zero" + NEMO_SIGMA)
                | (NEMO_SIGMA + "zero" + NEMO_SIGMA + "oh" + NEMO_SIGMA),
            ).optimize()

            # strings without "zero" followed later by "oh"
            # NOTE: everything rejected here is already rejected by no_oh_zero above
            no_zero_oh = pynini.difference(
                NEMO_SIGMA, NEMO_SIGMA + pynini.accep("zero") + NEMO_SIGMA + pynini.accep("oh") + NEMO_SIGMA
            ).optimize()

            # add the variant with an integer part of "zero" rewritten to "oh"
            self.final_graph_wo_negative |= pynini.compose(
                self.final_graph_wo_negative,
                pynini.cdrewrite(
                    pynini.cross("integer_part: \"zero\"", "integer_part: \"oh\""), NEMO_SIGMA, NEMO_SIGMA, NEMO_SIGMA
                ),
            )
            self.final_graph_wo_negative = pynini.compose(self.final_graph_wo_negative, no_oh_zero).optimize()
            self.final_graph_wo_negative = pynini.compose(self.final_graph_wo_negative, no_zero_oh).optimize()

        final_graph = optional_graph_negative + self.final_graph_wo_negative

        final_graph = self.add_tokens(final_graph)
        self.fst = final_graph.optimize()

View File

@@ -0,0 +1,87 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
NEMO_ALPHA,
NEMO_DIGIT,
NEMO_SIGMA,
GraphFst,
get_abs_path,
insert_space,
)
from pynini.lib import pynutil
class ElectronicFst(GraphFst):
    """
    Finite state transducer for classifying electronic: as URLs, email addresses, etc.
        e.g. cdf1@abc.edu -> tokens { electronic { username: "cdf1" domain: "abc.edu" } }

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="electronic", kind="classify", deterministic=deterministic)

        # input side only: which symbols / common domains are accepted
        accepted_symbols = pynini.project(pynini.string_file(get_abs_path("data/electronic/symbol.tsv")), "input")
        accepted_common_domains = pynini.project(
            pynini.string_file(get_abs_path("data/electronic/domain.tsv")), "input"
        )
        # a letter followed by any mix of letters, digits and accepted symbols
        all_accepted_symbols = NEMO_ALPHA + pynini.closure(NEMO_ALPHA | NEMO_DIGIT | accepted_symbols)
        # full symbol mapping, used to spell out the symbols of a protocol prefix
        graph_symbols = pynini.string_file(get_abs_path("data/electronic/symbol.tsv")).optimize()

        # "<user>@" -> username: "<user>" followed by a space
        username = (
            pynutil.insert("username: \"") + all_accepted_symbols + pynutil.insert("\"") + pynini.cross('@', ' ')
        )
        domain_graph = all_accepted_symbols + pynini.accep('.') + all_accepted_symbols + NEMO_ALPHA

        # every protocol symbol is spelled out and followed by an inserted space;
        # ":" is a colon, so it is verbalized as "colon" (it used to be mislabelled "semicolon")
        protocol_symbols = pynini.closure((graph_symbols | pynini.cross(":", "colon")) + pynutil.insert(" "))
        protocol_start = (pynini.cross("https", "HTTPS ") | pynini.cross("http", "HTTP ")) + (
            pynini.accep("://") @ protocol_symbols
        )
        protocol_file_start = pynini.accep("file") + insert_space + (pynini.accep(":///") @ protocol_symbols)

        # `@` binds tighter than `+`: only the "." is composed with protocol_symbols
        protocol_end = pynini.cross("www", "WWW ") + (pynini.accep(".") @ protocol_symbols)
        protocol = protocol_file_start | protocol_start | protocol_end | (protocol_start + protocol_end)

        # a domain must not itself start with a protocol prefix
        domain_graph = (
            pynutil.insert("domain: \"")
            + pynini.difference(domain_graph, pynini.project(protocol, "input") + NEMO_SIGMA)
            + pynutil.insert("\"")
        )
        # domain containing one of the common domains, optionally followed by
        # a symbol and further accepted characters
        domain_common_graph = (
            pynutil.insert("domain: \"")
            + pynini.difference(
                all_accepted_symbols
                + accepted_common_domains
                + pynini.closure(accepted_symbols + pynini.closure(NEMO_ALPHA | NEMO_DIGIT | accepted_symbols), 0, 1),
                pynini.project(protocol, "input") + NEMO_SIGMA,
            )
            + pynutil.insert("\"")
        )

        protocol = pynutil.insert("protocol: \"") + protocol + pynutil.insert("\"")
        # email
        graph = username + domain_graph
        # abc.com, abc.com/123-sm
        graph |= domain_common_graph
        # www.abc.com/sdafsdf, or https://www.abc.com/asdfad or www.abc.abc/asdfad
        graph |= protocol + pynutil.insert(" ") + domain_graph

        final_graph = self.add_tokens(graph)

        self.fst = final_graph.optimize()

View File

@@ -0,0 +1,55 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import GraphFst, get_abs_path
from pynini.lib import pynutil
class FractionFst(GraphFst):
    """
    Finite state transducer for classifying fraction
    "23 4/5" ->
    tokens { fraction { integer_part: "twenty three" numerator: "four" denominator: "five" } }
    "23 4/5th" ->
    tokens { fraction { integer_part: "twenty three" numerator: "four" denominator: "five" } }

    Args:
        cardinal: cardinal tagger; its ``graph`` attribute verbalizes the integer part,
            the numerator and the denominator
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, cardinal, deterministic: bool = True):
        super().__init__(name="fraction", kind="classify", deterministic=deterministic)
        cardinal_graph = cardinal.graph

        integer = pynutil.insert("integer_part: \"") + cardinal_graph + pynutil.insert("\"")

        # the slash (with or without surrounding spaces) closes the numerator field
        slash = pynini.cross("/", "\" ") | pynini.cross(" / ", "\" ")
        numerator = pynutil.insert("numerator: \"") + cardinal_graph + slash

        # an ordinal ending after the denominator, in lower or upper case, is dropped
        lower_endings = ("rd", "th", "st", "nd")
        endings = [*lower_endings, *(ending.upper() for ending in lower_endings)]
        optional_end = pynini.closure(pynini.cross(pynini.union(*endings), ""), 0, 1)
        denominator = pynutil.insert("denominator: \"") + cardinal_graph + optional_end + pynutil.insert("\"")

        numerator_denominator = numerator + denominator

        # "[<integer> ]<numerator>/<denominator>"
        graph = pynini.closure(integer + pynini.accep(" "), 0, 1) + numerator_denominator
        # same, with the fraction first rewritten through data/number/fraction.tsv;
        # here the space after the integer part may be missing and is then inserted
        graph |= pynini.closure(integer + (pynini.accep(" ") | pynutil.insert(" ")), 0, 1) + pynini.compose(
            pynini.string_file(get_abs_path("data/number/fraction.tsv")), numerator_denominator
        )

        self.graph = graph
        self.fst = self.add_tokens(self.graph).optimize()

View File

@@ -0,0 +1,304 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
NEMO_ALPHA,
NEMO_DIGIT,
NEMO_NON_BREAKING_SPACE,
NEMO_SIGMA,
NEMO_SPACE,
NEMO_UPPER,
SINGULAR_TO_PLURAL,
TO_LOWER,
GraphFst,
convert_space,
delete_space,
delete_zero_or_one_space,
insert_space,
)
from nemo_text_processing.text_normalization.en.taggers.ordinal import OrdinalFst as OrdinalTagger
from nemo_text_processing.text_normalization.en.taggers.whitelist import get_formats
from nemo_text_processing.text_normalization.en.utils import get_abs_path, load_labels
from nemo_text_processing.text_normalization.en.verbalizers.ordinal import OrdinalFst as OrdinalVerbalizer
from pynini.examples import plurals
from pynini.lib import pynutil
class MeasureFst(GraphFst):
    """
    Finite state transducer for classifying measure, suppletive aware, e.g.
        -12kg -> measure { negative: "true" cardinal { integer: "twelve" } units: "kilograms" }
        1kg -> measure { cardinal { integer: "one" } units: "kilogram" }
        .5kg -> measure { decimal { fractional_part: "five" } units: "kilograms" }

    Besides number + unit pairs, the final graph also unions in: a bare "/unit" or "per unit",
    decimals joined to letters by a dash (both orders), "<decimal>x", fractions with units,
    street addresses (see get_address_graph) and simple math expressions.

    Args:
        cardinal: CardinalFst
        decimal: DecimalFst
        fraction: FractionFst
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, cardinal: GraphFst, decimal: GraphFst, fraction: GraphFst, deterministic: bool = True):
        super().__init__(name="measure", kind="classify", deterministic=deterministic)
        # a "number" is either a plain cardinal or one of the range forms built by get_range
        cardinal_graph = cardinal.graph_with_and | self.get_range(cardinal.graph_with_and)

        graph_unit = pynini.string_file(get_abs_path("data/measure/unit.tsv"))
        if not deterministic:
            graph_unit |= pynini.string_file(get_abs_path("data/measure/unit_alternatives.tsv"))
        # Also look units up after rewriting letters through TO_LOWER; this path needs at least two
        # letters and the first one must be rewritten by TO_LOWER.
        # NOTE(review): TO_LOWER is presumably an upper->lower case map from graph_utils — confirm.
        graph_unit |= pynini.compose(
            pynini.closure(TO_LOWER, 1) + (NEMO_ALPHA | TO_LOWER) + pynini.closure(NEMO_ALPHA | TO_LOWER), graph_unit
        ).optimize()

        # the plural form must be derived before convert_space is applied to the singular graph
        graph_unit_plural = convert_space(graph_unit @ SINGULAR_TO_PLURAL)
        graph_unit = convert_space(graph_unit)

        # optional leading "-" is tagged as: negative: "true"
        optional_graph_negative = pynini.closure(pynutil.insert("negative: ") + pynini.cross("-", "\"true\" "), 0, 1)

        # "/unit" -> "per" + non-breaking space + unit
        graph_unit2 = (
            pynini.cross("/", "per") + delete_zero_or_one_space + pynutil.insert(NEMO_NON_BREAKING_SPACE) + graph_unit
        )

        optional_graph_unit2 = pynini.closure(
            delete_zero_or_one_space + pynutil.insert(NEMO_NON_BREAKING_SPACE) + graph_unit2, 0, 1,
        )

        # "+" binds tighter than "|": (unit [per unit]) | (per unit)
        unit_plural = (
            pynutil.insert("units: \"")
            + (graph_unit_plural + optional_graph_unit2 | graph_unit2)
            + pynutil.insert("\"")
        )

        unit_singular = (
            pynutil.insert("units: \"") + (graph_unit + optional_graph_unit2 | graph_unit2) + pynutil.insert("\"")
        )

        # decimals always take the plural unit
        subgraph_decimal = (
            pynutil.insert("decimal { ")
            + optional_graph_negative
            + decimal.final_graph_wo_negative
            + delete_space
            + pynutil.insert(" } ")
            + unit_plural
        )

        # support radio FM/AM
        subgraph_decimal |= (
            pynutil.insert("decimal { ")
            + decimal.final_graph_wo_negative
            + delete_space
            + pynutil.insert(" } ")
            + pynutil.insert("units: \"")
            + pynini.union("AM", "FM")
            + pynutil.insert("\"")
        )

        # every cardinal input other than exactly "1" takes the plural unit ...
        subgraph_cardinal = (
            pynutil.insert("cardinal { ")
            + optional_graph_negative
            + pynutil.insert("integer: \"")
            + ((NEMO_SIGMA - "1") @ cardinal_graph)
            + delete_space
            + pynutil.insert("\"")
            + pynutil.insert(" } ")
            + unit_plural
        )

        # ... while "1" takes the singular unit
        subgraph_cardinal |= (
            pynutil.insert("cardinal { ")
            + optional_graph_negative
            + pynutil.insert("integer: \"")
            + pynini.cross("1", "one")
            + delete_space
            + pynutil.insert("\"")
            + pynutil.insert(" } ")
            + unit_singular
        )

        # unit without a number ("/unit" or "per unit"); a placeholder integer "-" is inserted
        unit_graph = (
            pynutil.insert("cardinal { integer: \"-\" } units: \"")
            + pynini.cross(pynini.union("/", "per"), "per")
            + delete_zero_or_one_space
            + pynutil.insert(NEMO_NON_BREAKING_SPACE)
            + graph_unit
            + pynutil.insert("\" preserve_order: true")
        )

        # <decimal>-<letters>: the dash is removed and the letters become the units
        decimal_dash_alpha = (
            pynutil.insert("decimal { ")
            + decimal.final_graph_wo_negative
            + pynini.cross('-', '')
            + pynutil.insert(" } units: \"")
            + pynini.closure(NEMO_ALPHA, 1)
            + pynutil.insert("\"")
        )

        # <decimal>x / <decimal>X: units are normalized to lower-case "x"
        decimal_times = (
            pynutil.insert("decimal { ")
            + decimal.final_graph_wo_negative
            + pynutil.insert(" } units: \"")
            + pynini.cross(pynini.union('x', "X"), 'x')
            + pynutil.insert("\"")
        )

        # <letters>-<decimal>: the dash is kept inside the units value; written order is preserved
        alpha_dash_decimal = (
            pynutil.insert("units: \"")
            + pynini.closure(NEMO_ALPHA, 1)
            + pynini.accep('-')
            + pynutil.insert("\"")
            + pynutil.insert(" decimal { ")
            + decimal.final_graph_wo_negative
            + pynutil.insert(" } preserve_order: true")
        )

        subgraph_fraction = (
            pynutil.insert("fraction { ") + fraction.graph + delete_space + pynutil.insert(" } ") + unit_plural
        )

        # addresses are carried in the cardinal integer field under the pseudo unit "address"
        address = self.get_address_graph(cardinal)
        address = (
            pynutil.insert("units: \"address\" cardinal { integer: \"")
            + address
            + pynutil.insert("\" } preserve_order: true")
        )

        math_operations = pynini.string_file(get_abs_path("data/measure/math_operation.tsv"))
        # a delimiter is an existing space, or a space inserted where the input has none
        delimiter = pynini.accep(" ") | pynutil.insert(" ")

        # <operand> <operation> <number> = <operand>
        math = (
            (cardinal_graph | NEMO_ALPHA)
            + delimiter
            + math_operations
            + (delimiter | NEMO_ALPHA)
            + cardinal_graph
            + delimiter
            + pynini.cross("=", "equals")
            + delimiter
            + (cardinal_graph | NEMO_ALPHA)
        )

        # <operand> = <operand> <operation> <number>
        math |= (
            (cardinal_graph | NEMO_ALPHA)
            + delimiter
            + pynini.cross("=", "equals")
            + delimiter
            + (cardinal_graph | NEMO_ALPHA)
            + delimiter
            + math_operations
            + delimiter
            + cardinal_graph
        )

        # math expressions are carried in the cardinal integer field under the pseudo unit "math"
        math = (
            pynutil.insert("units: \"math\" cardinal { integer: \"")
            + math
            + pynutil.insert("\" } preserve_order: true")
        )
        final_graph = (
            subgraph_decimal
            | subgraph_cardinal
            | unit_graph
            | decimal_dash_alpha
            | decimal_times
            | alpha_dash_decimal
            | subgraph_fraction
            | address
            | math
        )
        final_graph = self.add_tokens(final_graph)
        self.fst = final_graph.optimize()

    def get_range(self, cardinal: GraphFst):
        """
        Returns range forms for measure tagger, e.g. 2-3, 2x3, 2*2

        "-" / " - " is read as " to ", "x" / " x " as " by " (and additionally " times " when
        self.deterministic is False), "*" / " * " as " times ".

        Args:
            cardinal: cardinal GraphFst
                (NOTE(review): __init__ in this class passes cardinal.graph_with_and here, i.e. the
                graph itself; the argument is only concatenated with other graphs)

        Returns:
            the optimized union of all range graphs
        """
        range_graph = cardinal + pynini.cross(pynini.union("-", " - "), " to ") + cardinal

        for x in [" x ", "x"]:
            range_graph |= cardinal + pynini.cross(x, " by ") + cardinal
            if not self.deterministic:
                range_graph |= cardinal + pynini.cross(x, " times ") + cardinal

        for x in ["*", " * "]:
            range_graph |= cardinal + pynini.cross(x, " times ") + cardinal
        return range_graph.optimize()

    def get_address_graph(self, cardinal):
        """
        Builds the graph that verbalizes a street address: house number, optional direction
        (E/S/W/N), street name ending in a word from data/address/address_word.tsv, and an
        optional ", city", ", state" and zip code, e.g.:
            2788 San Tomas Expy, Santa Clara, CA 95051 ->
            units: "address" cardinal
            { integer: "two seven eight eight San Tomas Expressway Santa Clara California nine five zero five one" }
            preserve_order: true
        (example carried over from the original docstring; the "units"/"cardinal" wrapper is added
        by __init__ — TODO confirm the exact reading of the house number)

        Args:
            cardinal: cardinal tagger; its graph, single_digits_graph and
                graph_hundred_component_at_least_one_none_zero_digit attributes are used

        Returns:
            the address graph, without the surrounding "units"/"cardinal" fields
        """
        # ordinal street names: tag the ordinal, then immediately verbalize it
        ordinal_verbalizer = OrdinalVerbalizer().graph
        ordinal_tagger = OrdinalTagger(cardinal=cardinal).graph
        ordinal_num = pynini.compose(
            pynutil.insert("integer: \"") + ordinal_tagger + pynutil.insert("\""), ordinal_verbalizer
        )

        # "**" binds tighter than "@", which binds tighter than "+":
        # (one or two digits) followed by (exactly two digits, with an optional leading "zero ")
        address_num = NEMO_DIGIT ** (1, 2) @ cardinal.graph_hundred_component_at_least_one_none_zero_digit
        address_num += insert_space + NEMO_DIGIT ** 2 @ (
            pynini.closure(pynini.cross("0", "zero "), 0, 1)
            + cardinal.graph_hundred_component_at_least_one_none_zero_digit
        )
        # to handle the rest of the numbers
        # the split reading above is restricted to 3-4 digit inputs; everything else falls back to cardinal.graph
        address_num = pynini.compose(NEMO_DIGIT ** (3, 4), address_num)
        address_num = plurals._priority_union(address_num, cardinal.graph, NEMO_SIGMA)

        # optional single-letter direction with an optional trailing period that is dropped
        direction = (
            pynini.cross("E", "East")
            | pynini.cross("S", "South")
            | pynini.cross("W", "West")
            | pynini.cross("N", "North")
        ) + pynini.closure(pynutil.delete("."), 0, 1)

        direction = pynini.closure(pynini.accep(NEMO_SPACE) + direction, 0, 1)
        address_words = get_formats(get_abs_path("data/address/address_word.tsv"))
        # street name: space, (optional ordinal | capitalized word of 2+ letters), space,
        # any number of further capitalized words, then a word from address_word.tsv
        address_words = (
            pynini.accep(NEMO_SPACE)
            + (pynini.closure(ordinal_num, 0, 1) | NEMO_UPPER + pynini.closure(NEMO_ALPHA, 1))
            + NEMO_SPACE
            + pynini.closure(NEMO_UPPER + pynini.closure(NEMO_ALPHA) + NEMO_SPACE)
            + address_words
        )

        # optional ", <letters and spaces>"
        city = pynini.closure(NEMO_ALPHA | pynini.accep(NEMO_SPACE), 1)
        city = pynini.closure(pynini.accep(",") + pynini.accep(NEMO_SPACE) + city, 0, 1)

        states = load_labels(get_abs_path("data/address/state.tsv"))

        # for every (x, y) row also accept y with a "." inserted after its first character
        additional_options = []
        for x, y in states:
            additional_options.append((x, f"{y[0]}.{y[1:]}"))
        states.extend(additional_options)
        state_graph = pynini.string_map(states)
        # inverted so that the second column of state.tsv is the input side
        state = pynini.invert(state_graph)
        state = pynini.closure(pynini.accep(",") + pynini.accep(NEMO_SPACE) + state, 0, 1)

        # exactly five digits, read through cardinal.single_digits_graph
        zip_code = pynini.compose(NEMO_DIGIT ** 5, cardinal.single_digits_graph)
        zip_code = pynini.closure(pynini.closure(pynini.accep(","), 0, 1) + pynini.accep(NEMO_SPACE) + zip_code, 0, 1,)

        address = address_num + direction + address_words + pynini.closure(city + state + zip_code, 0, 1)

        # street-only form with an optional trailing period that is removed
        address |= address_num + direction + address_words + pynini.closure(pynini.cross(".", ""), 0, 1)

        return address

View File

@@ -0,0 +1,192 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
NEMO_ALPHA,
NEMO_DIGIT,
NEMO_SIGMA,
SINGULAR_TO_PLURAL,
GraphFst,
convert_space,
insert_space,
)
from nemo_text_processing.text_normalization.en.utils import get_abs_path, load_labels
from pynini.lib import pynutil
# Currency lookup graphs, built once at import time from the TSV files under data/money
# and shared by MoneyFst below.
maj_singular = pynini.string_file(get_abs_path("data/money/currency_major.tsv"))
min_singular = pynini.string_file(get_abs_path("data/money/currency_minor_singular.tsv"))
min_plural = pynini.string_file(get_abs_path("data/money/currency_minor_plural.tsv"))
class MoneyFst(GraphFst):
    """
    Finite state transducer for classifying money, suppletive aware, e.g.
        $12.05 -> money { integer_part: "twelve" currency_maj: "dollars" fractional_part: "five" currency_min: "cents" preserve_order: true }
        $12.0500 -> money { integer_part: "twelve" currency_maj: "dollars" fractional_part: "five" currency_min: "cents" preserve_order: true }
        $1 -> money { currency_maj: "dollar" integer_part: "one" }
        $1.00 -> money { currency_maj: "dollar" integer_part: "one" }
        $0.05 -> money { fractional_part: "five" currency_min: "cents" preserve_order: true }
        $1 million -> money { currency_maj: "dollars" integer_part: "one" quantity: "million" }
        $1.2 million -> money { currency_maj: "dollars" integer_part: "one" fractional_part: "two" quantity: "million" }
        $1.2320 -> money { currency_maj: "dollars" integer_part: "one" fractional_part: "two three two" }

    Args:
        cardinal: CardinalFst
        decimal: DecimalFst
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, cardinal: GraphFst, decimal: GraphFst, deterministic: bool = True):
        super().__init__(name="money", kind="classify", deterministic=deterministic)
        cardinal_graph = cardinal.graph_with_and
        graph_decimal_final = decimal.final_graph_wo_negative_w_abbr

        # the same currency table as the module-level maj_singular graph, loaded as rows so the
        # currency symbols can be iterated one by one below
        maj_singular_labels = load_labels(get_abs_path("data/money/currency_major.tsv"))
        maj_unit_plural = convert_space(maj_singular @ SINGULAR_TO_PLURAL)
        maj_unit_singular = convert_space(maj_singular)

        graph_maj_singular = pynutil.insert("currency_maj: \"") + maj_unit_singular + pynutil.insert("\"")
        graph_maj_plural = pynutil.insert("currency_maj: \"") + maj_unit_plural + pynutil.insert("\"")

        # drops a fractional part that consists only of zeros, e.g. ".00"
        optional_delete_fractional_zeros = pynini.closure(
            pynutil.delete(".") + pynini.closure(pynutil.delete("0"), 1), 0, 1
        )

        graph_integer_one = pynutil.insert("integer_part: \"") + pynini.cross("1", "one") + pynutil.insert("\"")
        # only for decimals where third decimal after comma is non-zero or with quantity
        decimal_delete_last_zeros = (
            pynini.closure(NEMO_DIGIT | pynutil.delete(","))
            + pynini.accep(".")
            + pynini.closure(NEMO_DIGIT, 2)
            + (NEMO_DIGIT - "0")
            + pynini.closure(pynutil.delete("0"))
        )
        # any input that ends in a letter (the quantity word, e.g. "million" in the class examples)
        decimal_with_quantity = NEMO_SIGMA + NEMO_ALPHA

        # "@" binds tighter than "+": the decimal filter is composed with graph_decimal_final first
        graph_decimal = (
            graph_maj_plural + insert_space + (decimal_delete_last_zeros | decimal_with_quantity) @ graph_decimal_final
        )

        # every integer input other than exactly "1"
        graph_integer = (
            pynutil.insert("integer_part: \"") + ((NEMO_SIGMA - "1") @ cardinal_graph) + pynutil.insert("\"")
        )

        # "1" takes the singular currency name, everything else the plural
        graph_integer_only = graph_maj_singular + insert_space + graph_integer_one
        graph_integer_only |= graph_maj_plural + insert_space + graph_integer

        final_graph = (graph_integer_only + optional_delete_fractional_zeros) | graph_decimal

        # remove trailing zeros of non zero number in the first 2 digits and fill up to 2 digits
        # e.g. 2000 -> 20, 0200->02, 01 -> 01, 10 -> 10
        # not accepted: 002, 00, 0,
        two_digits_fractional_part = (
            pynini.closure(NEMO_DIGIT) + (NEMO_DIGIT - "0") + pynini.closure(pynutil.delete("0"))
        ) @ (
            (pynutil.delete("0") + (NEMO_DIGIT - "0"))
            | ((NEMO_DIGIT - "0") + pynutil.insert("0"))
            | ((NEMO_DIGIT - "0") + NEMO_DIGIT)
        )

        graph_min_singular = pynutil.insert(" currency_min: \"") + min_singular + pynutil.insert("\"")
        graph_min_plural = pynutil.insert(" currency_min: \"") + min_plural + pynutil.insert("\"")
        # format ** dollars ** cent
        # accumulators: each is the union over all currency symbols, built up inside the loop
        decimal_graph_with_minor = None
        integer_graph_reordered = None
        decimal_default_reordered = None
        for curr_symbol, _ in maj_singular_labels:
            preserve_order = pynutil.insert(" preserve_order: true")
            # the currency symbol is re-inserted after the number and mapped to its spoken name
            integer_plus_maj = graph_integer + insert_space + pynutil.insert(curr_symbol) @ graph_maj_plural
            integer_plus_maj |= graph_integer_one + insert_space + pynutil.insert(curr_symbol) @ graph_maj_singular

            # "-" and "+" have equal precedence and associate left: (NEMO_DIGIT - "0") + closure(...),
            # i.e. a non-zero first digit followed by digits, with "," separators deleted
            integer_plus_maj_with_comma = pynini.compose(
                NEMO_DIGIT - "0" + pynini.closure(NEMO_DIGIT | pynutil.delete(",")), integer_plus_maj
            )
            # any digit string except the single digit "0"
            integer_plus_maj = pynini.compose(pynini.closure(NEMO_DIGIT) - "0", integer_plus_maj)
            integer_plus_maj |= integer_plus_maj_with_comma

            # fractional part normalized to "1" takes the singular minor unit ...
            graph_fractional_one = two_digits_fractional_part @ pynini.cross("1", "one")
            graph_fractional_one = pynutil.insert("fractional_part: \"") + graph_fractional_one + pynutil.insert("\"")
            # ... every other one or two digit value takes the plural minor unit
            graph_fractional = (
                two_digits_fractional_part
                @ (pynini.closure(NEMO_DIGIT, 1, 2) - "1")
                @ cardinal.graph_hundred_component_at_least_one_none_zero_digit
            )
            graph_fractional = pynutil.insert("fractional_part: \"") + graph_fractional + pynutil.insert("\"")

            fractional_plus_min = graph_fractional + insert_space + pynutil.insert(curr_symbol) @ graph_min_plural
            fractional_plus_min |= (
                graph_fractional_one + insert_space + pynutil.insert(curr_symbol) @ graph_min_singular
            )

            decimal_graph_with_minor_curr = integer_plus_maj + pynini.cross(".", " ") + fractional_plus_min

            if not deterministic:
                # alternative without the minor currency name, at a small extra weight
                decimal_graph_with_minor_curr |= pynutil.add_weight(
                    integer_plus_maj
                    + pynini.cross(".", " ")
                    + pynutil.insert("fractional_part: \"")
                    + two_digits_fractional_part @ cardinal.graph_hundred_component_at_least_one_none_zero_digit
                    + pynutil.insert("\""),
                    weight=0.0001,
                )
                # NOTE(review): nesting reconstructed — this value is only read inside the
                # `if not deterministic` block further down, so its placement does not change behavior
                default_fraction_graph = (decimal_delete_last_zeros | decimal_with_quantity) @ graph_decimal_final
            # amounts with no (or a zero) integer part: only the minor currency is read
            decimal_graph_with_minor_curr |= (
                pynini.closure(pynutil.delete("0"), 0, 1) + pynutil.delete(".") + fractional_plus_min
            )
            decimal_graph_with_minor_curr = (
                pynutil.delete(curr_symbol) + decimal_graph_with_minor_curr + preserve_order
            )

            decimal_graph_with_minor = (
                decimal_graph_with_minor_curr
                if decimal_graph_with_minor is None
                else pynini.union(decimal_graph_with_minor, decimal_graph_with_minor_curr).optimize()
            )

            if not deterministic:
                integer_graph_reordered_curr = (
                    pynutil.delete(curr_symbol) + integer_plus_maj + preserve_order
                ).optimize()

                integer_graph_reordered = (
                    integer_graph_reordered_curr
                    if integer_graph_reordered is None
                    else pynini.union(integer_graph_reordered, integer_graph_reordered_curr).optimize()
                )
                decimal_default_reordered_curr = (
                    pynutil.delete(curr_symbol)
                    + default_fraction_graph
                    + insert_space
                    + pynutil.insert(curr_symbol) @ graph_maj_plural
                )

                decimal_default_reordered = (
                    decimal_default_reordered_curr
                    if decimal_default_reordered is None
                    else pynini.union(decimal_default_reordered, decimal_default_reordered_curr)
                ).optimize()

        # weight for SH
        # the negative weight makes the major + minor currency reading cheaper than the unweighted alternatives
        final_graph |= pynutil.add_weight(decimal_graph_with_minor, -0.0001)

        if not deterministic:
            # these two accumulators are only populated in non-deterministic mode
            final_graph |= integer_graph_reordered | decimal_default_reordered
            # to handle "$2.00" cases
            final_graph |= pynini.compose(
                NEMO_SIGMA + pynutil.delete(".") + pynini.closure(pynutil.delete("0"), 1), integer_graph_reordered
            )
        final_graph = self.add_tokens(final_graph.optimize())
        self.fst = final_graph.optimize()

View File

@@ -0,0 +1,61 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_DIGIT, GraphFst
from pynini.lib import pynutil
class OrdinalFst(GraphFst):
    """
    Finite state transducer for classifying ordinal, e.g.
        13th -> ordinal { integer: "thirteen" }

    The written suffix (st/nd/rd/th, lower or upper case) is validated against the trailing
    digits and removed; the remaining digits are passed to the cardinal graph.

    Args:
        cardinal: CardinalFst
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, cardinal: GraphFst, deterministic: bool = True):
        super().__init__(name="ordinal", kind="classify", deterministic=deterministic)

        digits_and_commas = pynini.closure(NEMO_DIGIT | pynini.accep(","))
        any_but_one = NEMO_DIGIT - "1"
        any_but_one_two_three = NEMO_DIGIT - "1" - "2" - "3"

        def suffixed(last_digit, lower_suffix, upper_suffix):
            # optional leading digits whose final digit is not "1" (rules out 11/12/13),
            # then the required last digit, then the suffix which is deleted
            leading = pynini.closure(digits_and_commas + any_but_one, 0, 1)
            suffix = pynutil.delete(pynini.union(lower_suffix, upper_suffix))
            return leading + pynini.accep(last_digit) + suffix

        st_format = suffixed("1", "st", "ST")
        nd_format = suffixed("2", "nd", "ND")
        rd_format = suffixed("3", "rd", "RD")

        # one or more repetitions of: a single digit other than 1/2/3, a number ending in 1<digit>,
        # or a number whose last two digits are (not 1)(not 1/2/3)
        th_body = (
            any_but_one_two_three
            | (digits_and_commas + "1" + NEMO_DIGIT)
            | (digits_and_commas + any_but_one + any_but_one_two_three)
        )
        th_format = pynini.closure(th_body, 1) + pynutil.delete(pynini.union("th", "TH"))

        written_forms = st_format | nd_format | rd_format | th_format
        self.graph = written_forms @ cardinal.graph

        tagged = pynutil.insert("integer: \"") + self.graph + pynutil.insert("\"")
        self.fst = self.add_tokens(tagged).optimize()

View File

@@ -0,0 +1,65 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from unicodedata import category
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_SPACE, NEMO_SIGMA, GraphFst
from nemo_text_processing.text_normalization.en.utils import get_abs_path, load_labels
from pynini.examples import plurals
from pynini.lib import pynutil
class PunctuationFst(GraphFst):
    """
    Finite state transducer for classifying punctuation
        e.g. a, -> tokens { name: "a" } tokens { name: "," }

    The accepted marks are kept in ``self.punct_marks``. Markup of the form <...>, </...> or
    <.../> takes priority over plain punctuation runs.

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="punctuation", kind="classify", deterministic=deterministic)

        ascii_marks = "!#%&\'()*+,-./:;<=>?@^_`{|}~\""
        excluded_marks = ["[", "]"]

        # every code point whose Unicode category starts with "P", minus the excluded brackets
        unicode_marks = []
        for code_point in range(sys.maxunicode):
            symbol = chr(code_point)
            if category(symbol).startswith("P") and symbol not in excluded_marks:
                unicode_marks.append(symbol)

        # symbols listed in the whitelist file are not treated as punctuation
        whitelisted = [row[0] for row in load_labels(get_abs_path("data/whitelist/symbol.tsv"))]
        self.punct_marks = [mark for mark in unicode_marks + list(ascii_marks) if mark not in whitelisted]

        punct_run = pynini.closure(pynini.union(*self.punct_marks), 1)

        # tag body: one or more non-space symbols other than "<" and ">"
        tag_body = pynini.closure(NEMO_NOT_SPACE - pynini.union("<", ">"), 1)
        optional_slash = pynini.closure(pynini.accep("/"), 0, 1)
        inside_brackets = (tag_body + optional_slash) | (pynini.accep("/") + tag_body)
        emphasis = pynini.accep("<") + inside_brackets + pynini.accep(">")

        self.graph = plurals._priority_union(emphasis, punct_run, NEMO_SIGMA)
        tagged = pynutil.insert("name: \"") + self.graph + pynutil.insert("\"")
        self.fst = tagged.optimize()

View File

@@ -0,0 +1,102 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_DIGIT, GraphFst, convert_space
from pynini.lib import pynutil
class RangeFst(GraphFst):
    """
    Tagger for ranges and range-like expressions: time ranges, year ranges, "mid-" years,
    "~" approximations, sums and "..." between cardinals. When ``deterministic`` is False or
    ``lm`` is True it also covers dash/colon ranges, products, "No. <number>" and divisions.

    Args:
        time: composed tagger and verbalizer
        date: composed tagger and verbalizer
        cardinal: tagger
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
        lm: whether to use for hybrid LM
    """

    def __init__(
        self, time: GraphFst, date: GraphFst, cardinal: GraphFst, deterministic: bool = True, lm: bool = False,
    ):
        super().__init__(name="range", kind="classify", deterministic=deterministic)

        optional_space = pynini.closure(pynutil.delete(" "), 0, 1)
        approximately = pynini.cross("~", "approximately")
        number = cardinal.graph_with_and

        def spaced(separator):
            # separator with an optional single space removed on either side
            return optional_space + separator + optional_space

        def chained(operator):
            # number (operator number)+
            return number + pynini.closure(operator + number, 1)

        # TIME
        ranges = time + spaced(pynini.cross("-", " to ")) + time
        ranges |= approximately + time

        # YEAR
        optional_s = pynini.closure(pynini.accep("s"), 0, 1)
        four_digit_year = (NEMO_DIGIT ** 4 + optional_s) @ date
        two_digit_year = (NEMO_DIGIT ** 2 + optional_s) @ date
        year_end = four_digit_year | two_digit_year | (NEMO_DIGIT ** 2 @ number)
        ranges |= four_digit_year + spaced(pynini.cross("-", " to ")) + year_end
        ranges |= pynini.accep("mid") + pynini.cross("-", " ") + (four_digit_year | two_digit_year)

        # ADDITION
        arithmetic = chained(pynini.cross("+", " plus "))
        arithmetic |= chained(pynini.cross(" + ", " plus "))
        arithmetic |= approximately + number
        arithmetic |= number + (pynini.cross("...", " ... ") | pynini.accep(" ... ")) + number

        # NOTE(review): SOURCE lost its indentation; everything from here through the division
        # loop is assumed to belong to this branch — confirm against the original file.
        if not deterministic or lm:
            # cardinal ----
            arithmetic |= (number + spaced(pynini.cross("-", pynini.union(" to ", " minus "))) + number) | (
                number + spaced(pynini.cross(":", " to ")) + number
            )

            # MULTIPLY
            for sign in [" x ", "x"]:
                arithmetic |= chained(pynini.cross(sign, pynini.union(" by ", " times ")))

            for sign in ["*", " * "]:
                arithmetic |= chained(pynini.cross(sign, " times "))

            # supports "No. 12" -> "Number 12"
            number_word = pynini.cross(pynini.union("NO", "No"), "Number") | pynini.cross("no", "number")
            arithmetic |= number_word + pynini.closure(pynini.union(". ", " "), 0, 1) + number

            for sign in ["/", " / "]:
                arithmetic |= chained(pynini.cross(sign, " divided by "))

        ranges |= arithmetic
        self.graph = ranges.optimize()

        tagged = pynutil.insert("name: \"") + convert_space(self.graph).optimize() + pynutil.insert("\"")
        self.fst = tagged.optimize()

View File

@@ -0,0 +1,114 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_ALPHA, NEMO_SIGMA, GraphFst
from nemo_text_processing.text_normalization.en.utils import get_abs_path, load_labels
from pynini.lib import pynutil
class RomanFst(GraphFst):
    """
    Finite state transducer for classifying roman numbers:
        e.g. "IV" -> tokens { roman { integer: "four" } }

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
        lm: whether to use for hybrid LM
    """

    def __init__(self, deterministic: bool = True, lm: bool = False):
        super().__init__(name="roman", kind="classify", deterministic=deterministic)

        roman_rows = load_labels(get_abs_path("data/roman/roman_to_spoken.tsv"))
        spoken = pynini.string_map(roman_rows).optimize()
        integer_graph = pynutil.insert("integer: \"") + spoken + pynutil.insert("\"")

        # deterministic mode skips the first table row (exclude "I")
        first_row = 1 if deterministic else 0
        ordinal_limit = 19
        graph_teens = pynini.string_map([row[0] for row in roman_rows[first_row:ordinal_limit]]).optimize()

        # roman numerals up to ordinal_limit with a preceding name are converted to ordinal form
        graph = (
            pynutil.insert("key_the_ordinal: \"")
            + get_names()
            + pynutil.insert("\"")
            + pynini.accep(" ")
            + (graph_teens @ integer_graph)
        ).optimize()

        # single symbol roman numerals with preceding key words (multiple formats) are converted to cardinal form
        # each key word is accepted as loaded, with its first letter upper-cased, and fully upper-cased
        key_word_rows = []
        for row in load_labels(get_abs_path("data/roman/key_word.tsv")):
            word = row[0]
            key_word_rows.extend([row, [word[0].upper() + word[1:]], [word.upper()]])
        key_words = pynini.string_map(key_word_rows).optimize()

        graph |= (
            pynutil.insert("key_cardinal: \"") + key_words + pynutil.insert("\"") + pynini.accep(" ") + integer_graph
        ).optimize()

        if deterministic or lm:
            # two digit roman numerals up to 49
            first_fifty = pynini.string_map([row[0] for row in roman_rows[:50]]).optimize()
            roman_to_cardinal = pynini.compose(
                pynini.closure(NEMO_ALPHA, 2),
                pynutil.insert("default_cardinal: \"default\" ") + (first_fifty @ integer_graph),
            )
        else:
            # two or more digit roman numerals
            roman_to_cardinal = pynini.compose(
                pynini.difference(NEMO_SIGMA, "I"),
                (
                    pynutil.insert("default_cardinal: \"default\" integer: \"")
                    + pynini.string_map(roman_rows).optimize()
                    + pynutil.insert("\"")
                ),
            ).optimize()
        graph |= roman_to_cardinal

        # convert three digit roman or up with suffix to ordinal
        graph |= pynini.compose(
            pynini.closure(NEMO_ALPHA, 3),
            pynutil.insert("default_ordinal: \"default\" ") + (graph_teens @ integer_graph) + pynutil.delete("th"),
        )

        self.fst = self.add_tokens(graph.optimize()).optimize()
def get_names():
    """
    Returns the graph that matched common male and female names.

    Every name is accepted both as listed in the data file and fully upper-cased.
    """

    def names_from(relative_path):
        # labels as listed in the file, followed by their upper-cased variants
        listed = load_labels(get_abs_path(relative_path))
        upper_cased = [[row[0].upper()] for row in listed]
        return pynini.string_map(listed + upper_cased).optimize()

    names = names_from("data/roman/male.tsv")
    names |= names_from("data/roman/female.tsv")
    return names

View File

@@ -0,0 +1,136 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
NEMO_ALPHA,
NEMO_DIGIT,
NEMO_NOT_SPACE,
NEMO_SIGMA,
GraphFst,
convert_space,
)
from nemo_text_processing.text_normalization.en.utils import get_abs_path, load_labels
from pynini.examples import plurals
from pynini.lib import pynutil
class SerialFst(GraphFst):
    """
    Finite state transducer for classifying serial (handles only cases without delimiters,
    values with delimiters are handled by default).
    The serial is a combination of digits, letters and dashes, e.g.:
        c325b -> tokens { cardinal { integer: "c three two five b" } }

    Args:
        cardinal: cardinal tagger; its graph, single_digits_graph and
            graph_hundred_component_at_least_one_none_zero_digit attributes are used
        ordinal: ordinal tagger; inputs accepted by ordinal.graph are excluded from serials
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
        lm: whether to use for hybrid LM
    """

    def __init__(self, cardinal: GraphFst, ordinal: GraphFst, deterministic: bool = True, lm: bool = False):
        super().__init__(name="integer", kind="classify", deterministic=deterministic)

        # six or more digits are read digit by digit, one to five digits as a cardinal number
        num_graph = pynini.compose(NEMO_DIGIT ** (6, ...), cardinal.single_digits_graph).optimize()
        num_graph |= pynini.compose(NEMO_DIGIT ** (1, 5), cardinal.graph).optimize()
        # to handle numbers starting with zero
        num_graph |= pynini.compose(
            pynini.accep("0") + pynini.closure(NEMO_DIGIT), cardinal.single_digits_graph
        ).optimize()
        # TODO: "#" doesn't work from the file
        symbols_graph = pynini.string_file(get_abs_path("data/whitelist/symbol.tsv")).optimize() | pynini.cross(
            "#", "hash"
        )
        num_graph |= symbols_graph

        # NOTE(review): SOURCE lost its indentation; both statements below are assumed to belong
        # to this branch — confirm against the original file.
        if not self.deterministic and not lm:
            num_graph |= cardinal.single_digits_graph
            # also allow double digits to be pronounced as integer in serial number
            num_graph |= pynutil.add_weight(
                NEMO_DIGIT ** 2 @ cardinal.graph_hundred_component_at_least_one_none_zero_digit, weight=0.0001
            )

        # add space between letter and digit/symbol
        symbols = [x[0] for x in load_labels(get_abs_path("data/whitelist/symbol.tsv"))]
        symbols = pynini.union(*symbols)
        digit_symbol = NEMO_DIGIT | symbols

        graph_with_space = pynini.compose(
            pynini.cdrewrite(pynutil.insert(" "), NEMO_ALPHA | symbols, digit_symbol, NEMO_SIGMA),
            pynini.cdrewrite(pynutil.insert(" "), digit_symbol, NEMO_ALPHA | symbols, NEMO_SIGMA),
        )

        # serial graph with delimiter
        delimiter = pynini.accep("-") | pynini.accep("/") | pynini.accep(" ")
        if not deterministic:
            delimiter |= pynini.cross("-", " dash ") | pynini.cross("/", " slash ")

        alphas = pynini.closure(NEMO_ALPHA, 1)
        letter_num = alphas + delimiter + num_graph
        num_letter = pynini.closure(num_graph + delimiter, 1) + alphas
        next_alpha_or_num = pynini.closure(delimiter + (alphas | num_graph))
        # a number may be followed by letters with an existing space kept, otherwise one is inserted
        next_alpha_or_num |= pynini.closure(
            delimiter
            + num_graph
            + plurals._priority_union(pynini.accep(" "), pynutil.insert(" "), NEMO_SIGMA).optimize()
            + alphas
        )

        serial_graph = letter_num + next_alpha_or_num
        serial_graph |= num_letter + next_alpha_or_num
        # numbers only with 2+ delimiters
        serial_graph |= (
            num_graph + delimiter + num_graph + delimiter + num_graph + pynini.closure(delimiter + num_graph)
        )
        # 2+ symbols
        serial_graph |= pynini.compose(NEMO_SIGMA + symbols + NEMO_SIGMA, num_graph + delimiter + num_graph)

        # exclude ordinal numbers from serial options
        serial_graph = pynini.compose(
            pynini.difference(NEMO_SIGMA, pynini.project(ordinal.graph, "input")), serial_graph
        ).optimize()

        serial_graph = pynutil.add_weight(serial_graph, 0.0001)
        # "<anything>^2" / "<anything>^3" are read as " squared" / " cubed" (added without the weight above)
        serial_graph |= (
            pynini.closure(NEMO_NOT_SPACE, 1)
            + (pynini.cross("^2", " squared") | pynini.cross("^3", " cubed")).optimize()
        )

        # at least one serial graph with alpha numeric value and optional additional serial/num/alpha values
        serial_graph = (
            pynini.closure((serial_graph | num_graph | alphas) + delimiter)
            + serial_graph
            + pynini.closure(delimiter + (serial_graph | num_graph | alphas))
        )

        serial_graph |= pynini.compose(graph_with_space, serial_graph.optimize()).optimize()
        # the written form must be at least two non-space symbols long
        serial_graph = pynini.compose(pynini.closure(NEMO_NOT_SPACE, 2), serial_graph).optimize()

        # this is not to verbalize "/" as "slash" in cases like "import/export"
        serial_graph = pynini.compose(
            pynini.difference(
                NEMO_SIGMA, pynini.closure(NEMO_ALPHA, 1) + pynini.accep("/") + pynini.closure(NEMO_ALPHA, 1)
            ),
            serial_graph,
        )
        self.graph = serial_graph.optimize()
        graph = pynutil.insert("name: \"") + convert_space(self.graph).optimize() + pynutil.insert("\"")
        self.fst = graph.optimize()

View File

@@ -0,0 +1,133 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
NEMO_ALPHA,
NEMO_DIGIT,
NEMO_SIGMA,
GraphFst,
delete_extra_space,
delete_space,
insert_space,
plurals,
)
from nemo_text_processing.text_normalization.en.utils import get_abs_path
from pynini.lib import pynutil
class TelephoneFst(GraphFst):
    """
    Tagger for telephone numbers, IP addresses and SSNs.

    A telephone number consists of an optional country code, a number part and
    an optional extension:
        country code (optional): +***
        number part: ***-***-****, or (***) ***-****
        extension (optional): 1-9999
    E.g.
        +1 123-123-5678-1 -> telephone { country_code: "one" number_part: "one two three, one two three, five six seven eight" extension: "one" }
        1-800-GO-U-HAUL -> telephone { country_code: "one" number_part: "one, eight hundred GO U HAUL" }

    IP addresses (four dot-separated groups of 1-3 digits) and SSNs
    (***-**-****) are emitted through the same `number_part` field, optionally
    preceded by a prompt stored in `country_code`.

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="telephone", kind="classify", deterministic=deterministic)

        def quoted_field(field_name, content):
            # Emits `field_name: "<content>"` around the given transducer.
            return pynutil.insert(f'{field_name}: "') + content + pynutil.insert('"')

        def prompted_number(prompts, number):
            # Optional prompt (stored as country_code) followed by the number itself.
            optional_prompt = pynini.closure(quoted_field("country_code", prompts) + delete_extra_space, 0, 1)
            return optional_prompt + quoted_field("number_part", number.optimize())

        # ---- single digits -------------------------------------------------
        spoken_zero = pynini.cross("0", "zero")
        if not deterministic:
            # extra readings of "0" are only offered in non-deterministic mode
            spoken_zero |= pynini.cross("0", pynini.union("o", "oh"))
        digit_names = pynini.string_file(get_abs_path("data/number/digit.tsv"))
        digit = pynini.invert(digit_names).optimize() | spoken_zero

        # ---- country code --------------------------------------------------
        telephone_prompts = pynini.string_file(get_abs_path("data/telephone/telephone_prompt.tsv"))
        optional_prompt = pynini.closure(telephone_prompts + delete_extra_space, 0, 1)
        optional_plus = pynini.closure(pynini.cross("+", "plus "), 0, 1)
        # one to three digits, read one by one, followed by an inserted comma
        code_digits = pynini.closure(digit + insert_space, 0, 2) + digit + pynutil.insert(",")
        country_code = optional_prompt + optional_plus + code_digits
        # a bare prompt is accepted as well
        country_code |= telephone_prompts
        country_code = quoted_field("country_code", country_code)
        country_code = country_code + pynini.closure(pynutil.delete("-"), 0, 1) + delete_space + insert_space

        # ---- area code -----------------------------------------------------
        three_spoken_digits = pynini.closure(digit + insert_space, 2, 2) + digit
        # "800" is read as a whole number; every other area code digit by digit
        area_code = pynini.cross("800", "eight hundred") | pynini.compose(
            pynini.difference(NEMO_SIGMA, "800"), three_spoken_digits
        )
        bare_area = area_code + (pynutil.delete("-") | pynutil.delete("."))
        closing_paren = (pynutil.delete(")") + pynini.closure(pynutil.delete(" "), 0, 1)) | pynutil.delete(")-")
        parenthesized_area = pynutil.delete("(") + area_code + closing_paren
        component_separator = pynutil.insert(", ")  # between components
        area_part = (bare_area | parenthesized_area) + component_separator

        # ---- subscriber number ----------------------------------------------
        optional_separator = pynini.closure(pynini.union("-", " ", "."), 0, 1)
        # exactly seven digits/letters, each optionally followed by a separator
        seven_symbols = ((NEMO_DIGIT + optional_separator) | (NEMO_ALPHA + optional_separator)) ** 7

        def spoken_symbols(separator):
            # Digits are spelled out; `separator` becomes ", " after a digit
            # and " " after a letter. Letters are passed through unchanged.
            return pynini.closure(
                (NEMO_DIGIT @ digit) + (insert_space | pynini.cross(separator, ", "))
                | NEMO_ALPHA
                | (NEMO_ALPHA + pynini.cross(separator, " "))
            )

        number_words = spoken_symbols("-")
        number_words |= spoken_symbols(".")
        number_words = pynini.compose(seven_symbols, number_words)
        number_part = quoted_field("number_part", area_part + number_words)

        # ---- extension (one to four digits, optional) -----------------------
        extension_digits = pynini.closure(digit + insert_space, 0, 3) + digit
        extension = pynini.closure(insert_space + quoted_field("extension", extension_digits), 0, 1)

        # Longer analyses take priority over the bare number part.
        graph = number_part
        for preferred in (
            country_code + number_part,
            country_code + number_part + extension,
            number_part + extension,
        ):
            graph = plurals._priority_union(preferred, graph, NEMO_SIGMA).optimize()

        # ---- ip -------------------------------------------------------------
        ip_prompts = pynini.string_file(get_abs_path("data/telephone/ip_prompt.tsv"))
        ip_group = digit + pynini.closure(pynutil.insert(" ") + digit, 0, 2)
        ip_graph = ip_group + (pynini.cross(".", " dot ") + ip_group) ** 3
        graph |= prompted_number(ip_prompts, ip_graph)

        # ---- ssn ------------------------------------------------------------
        ssn_prompts = pynini.string_file(get_abs_path("data/telephone/ssn_prompt.tsv"))
        spaced_digit = pynutil.insert(" ") + digit
        ssn_area = digit + spaced_digit ** 2
        ssn_group = digit + pynutil.insert(" ") + digit
        ssn_serial = digit + spaced_digit ** 3
        ssn_dash = pynini.cross("-", ", ")
        ssn_graph = ssn_area + ssn_dash + ssn_group + ssn_dash + ssn_serial
        graph |= prompted_number(ssn_prompts, ssn_graph)

        self.fst = self.add_tokens(graph).optimize()

View File

@@ -0,0 +1,132 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
NEMO_DIGIT,
GraphFst,
convert_space,
delete_space,
insert_space,
)
from nemo_text_processing.text_normalization.en.utils import (
augment_labels_with_punct_at_end,
get_abs_path,
load_labels,
)
from pynini.lib import pynutil
class TimeFst(GraphFst):
    """
    Tagger for time expressions, e.g.
        12:30 a.m. est -> time { hours: "twelve" minutes: "thirty" suffix: "a m" zone: "e s t" }
        2.30 a.m. -> time { hours: "two" minutes: "thirty" suffix: "a m" }
        02.30 a.m. -> time { hours: "two" minutes: "thirty" suffix: "a m" }
        2.00 a.m. -> time { hours: "two" suffix: "a m" }
        2 a.m. -> time { hours: "two" suffix: "a m" }
        02:00 -> time { hours: "two" }
        2:00 -> time { hours: "two" }
        10:00:05 a.m. -> time { hours: "ten" minutes: "zero" seconds: "o five" suffix: "a m" }

    Minutes and seconds written with a leading zero ("05") are tagged as "o"
    followed by the single-digit value; "00" is dropped in the h:mm / h.mm
    forms and tagged as "zero" in the h:mm:ss form.

    Args:
        cardinal: CardinalFst
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, cardinal: GraphFst, deterministic: bool = True):
        super().__init__(name="time", kind="classify", deterministic=deterministic)

        def field(name, content):
            # Emits `name: "<content>"` around the given transducer.
            return pynutil.insert(f'{name}: "') + content + pynutil.insert('"')

        # Suffixes (am/pm, ...) plus their variants with trailing punctuation.
        suffix_labels = load_labels(get_abs_path("data/time/suffix.tsv"))
        suffix_labels.extend(augment_labels_with_punct_at_end(suffix_labels))
        suffix_graph = pynini.string_map(suffix_labels)
        time_zone_graph = pynini.string_file(get_abs_path("data/time/zone.tsv"))

        # only used for < 1000 thousand -> 0 weight
        number_graph = cardinal.graph

        # Accept either two digits as-is or one digit with an optional leading zero removed.
        strip_leading_zero = (NEMO_DIGIT + NEMO_DIGIT) | (pynini.closure(pynutil.delete("0"), 0, 1) + NEMO_DIGIT)
        hours = strip_leading_zero @ pynini.union(*map(str, range(0, 24))) @ number_graph
        one_to_nine = pynini.union(*map(str, range(1, 10))) @ number_graph
        ten_to_fifty_nine = pynini.union(*map(str, range(10, 60))) @ number_graph

        # "0x" -> "o <x>", "xx" -> cardinal; shared by minutes and seconds.
        two_digit_value = (pynini.cross("0", "o") + insert_space + one_to_nine) | ten_to_fifty_nine

        hour_field = field("hours", hours)
        minute_field = field("minutes", two_digit_value)
        second_field = field("seconds", two_digit_value)
        suffix_field = field("suffix", convert_space(suffix_graph))

        optional_suffix = pynini.closure(delete_space + insert_space + suffix_field, 0, 1)
        optional_zone = pynini.closure(
            delete_space + insert_space + field("zone", convert_space(time_zone_graph)), 0, 1,
        )

        # "00" minutes are dropped, anything else becomes a minutes field.
        minutes_or_nothing = pynutil.delete("00") | (insert_space + minute_field)

        # 2:30 pm, 02:30, 2:00
        colon_hm = hour_field + pynutil.delete(":") + minutes_or_nothing + optional_suffix + optional_zone

        # 10:30:05 pm,
        colon_hms = (
            hour_field
            + pynutil.delete(":")
            + (pynini.cross("00", " minutes: \"zero\"") | (insert_space + minute_field))
            + pynutil.delete(":")
            + (pynini.cross("00", " seconds: \"zero\"") | (insert_space + second_field))
            + optional_suffix
            + optional_zone
        )

        # 2.xx pm/am -- the suffix is mandatory when "." separates hours and minutes
        dot_hm = (
            hour_field
            + pynutil.delete(".")
            + minutes_or_nothing
            + delete_space
            + insert_space
            + suffix_field
            + optional_zone
        )

        # 2 pm est
        hour_only = hour_field + delete_space + insert_space + suffix_field + optional_zone

        all_forms = (colon_hm | hour_only | dot_hm | colon_hms).optimize()
        self.fst = self.add_tokens(all_forms).optimize()

View File

@@ -0,0 +1,201 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
NEMO_WHITE_SPACE,
GraphFst,
delete_extra_space,
delete_space,
generator_main,
)
from nemo_text_processing.text_normalization.en.taggers.abbreviation import AbbreviationFst
from nemo_text_processing.text_normalization.en.taggers.cardinal import CardinalFst
from nemo_text_processing.text_normalization.en.taggers.date import DateFst
from nemo_text_processing.text_normalization.en.taggers.decimal import DecimalFst
from nemo_text_processing.text_normalization.en.taggers.electronic import ElectronicFst
from nemo_text_processing.text_normalization.en.taggers.fraction import FractionFst
from nemo_text_processing.text_normalization.en.taggers.measure import MeasureFst
from nemo_text_processing.text_normalization.en.taggers.money import MoneyFst
from nemo_text_processing.text_normalization.en.taggers.ordinal import OrdinalFst
from nemo_text_processing.text_normalization.en.taggers.punctuation import PunctuationFst
from nemo_text_processing.text_normalization.en.taggers.range import RangeFst as RangeFst
from nemo_text_processing.text_normalization.en.taggers.roman import RomanFst
from nemo_text_processing.text_normalization.en.taggers.serial import SerialFst
from nemo_text_processing.text_normalization.en.taggers.telephone import TelephoneFst
from nemo_text_processing.text_normalization.en.taggers.time import TimeFst
from nemo_text_processing.text_normalization.en.taggers.whitelist import WhiteListFst
from nemo_text_processing.text_normalization.en.taggers.word import WordFst
from nemo_text_processing.text_normalization.en.verbalizers.date import DateFst as vDateFst
from nemo_text_processing.text_normalization.en.verbalizers.ordinal import OrdinalFst as vOrdinalFst
from nemo_text_processing.text_normalization.en.verbalizers.time import TimeFst as vTimeFst
from pynini.lib import pynutil
class ClassifyFst(GraphFst):
    """
    Final class that composes all other classification grammars. This class can process an entire sentence including punctuation.
    For deployment, this grammar will be compiled and exported to OpenFst Finite State Archive (FAR) File.
    More details to deployment at NeMo/tools/text_processing_deployment.

    Args:
        input_case: accepting either "lower_cased" or "cased" input.
        deterministic: if True will provide a single transduction option,
            for False multiple options (used for audio-based normalization)
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
        whitelist: path to a file with whitelist replacements
    """

    def __init__(
        self,
        input_case: str,
        deterministic: bool = True,
        cache_dir: str = None,
        overwrite_cache: bool = False,
        whitelist: str = None,
    ):
        super().__init__(name="tokenize_and_classify", kind="classify", deterministic=deterministic)

        # The cache file name encodes every option that affects the compiled grammar.
        far_file = None
        if cache_dir is not None and cache_dir != "None":
            os.makedirs(cache_dir, exist_ok=True)
            whitelist_file = os.path.basename(whitelist) if whitelist else ""
            far_file = os.path.join(
                cache_dir, f"en_tn_{deterministic}_deterministic_{input_case}_{whitelist_file}_tokenize.far"
            )
        if not overwrite_cache and far_file and os.path.exists(far_file):
            # Reuse the previously compiled grammar.
            self.fst = pynini.Far(far_file, mode="r")["tokenize_and_classify"]
        else:
            # Build every semiotic-class tagger. Cardinal, ordinal, decimal and fraction
            # taggers are kept as objects because other taggers are constructed from them.
            cardinal = CardinalFst(deterministic=deterministic)
            cardinal_graph = cardinal.fst

            ordinal = OrdinalFst(cardinal=cardinal, deterministic=deterministic)
            ordinal_graph = ordinal.fst

            decimal = DecimalFst(cardinal=cardinal, deterministic=deterministic)
            decimal_graph = decimal.fst

            fraction = FractionFst(deterministic=deterministic, cardinal=cardinal)
            fraction_graph = fraction.fst

            measure = MeasureFst(cardinal=cardinal, decimal=decimal, fraction=fraction, deterministic=deterministic)
            measure_graph = measure.fst

            date_graph = DateFst(cardinal=cardinal, deterministic=deterministic).fst
            time_graph = TimeFst(cardinal=cardinal, deterministic=deterministic).fst
            telephone_graph = TelephoneFst(deterministic=deterministic).fst
            electronic_graph = ElectronicFst(deterministic=deterministic).fst
            money_graph = MoneyFst(cardinal=cardinal, decimal=decimal, deterministic=deterministic).fst
            whitelist_graph = WhiteListFst(
                input_case=input_case, deterministic=deterministic, input_file=whitelist
            ).fst

            punctuation = PunctuationFst(deterministic=deterministic)
            punct_graph = punctuation.fst
            word_graph = WordFst(punctuation=punctuation, deterministic=deterministic).fst
            serial_graph = SerialFst(cardinal=cardinal, ordinal=ordinal, deterministic=deterministic).fst

            # RangeFst takes already-verbalized time and date graphs, so the
            # corresponding verbalizers are composed onto the taggers here.
            v_time_graph = vTimeFst(deterministic=deterministic).fst
            v_ordinal_graph = vOrdinalFst(deterministic=deterministic)
            v_date_graph = vDateFst(ordinal=v_ordinal_graph, deterministic=deterministic).fst
            time_final = pynini.compose(time_graph, v_time_graph)
            date_final = pynini.compose(date_graph, v_date_graph)
            range_graph = RangeFst(
                time=time_final, date=date_final, cardinal=cardinal, deterministic=deterministic
            ).fst

            # Union of all semiotic classes; the weights rank competing analyses.
            classify = (
                pynutil.add_weight(whitelist_graph, 1.01)
                | pynutil.add_weight(time_graph, 1.1)
                | pynutil.add_weight(date_graph, 1.09)
                | pynutil.add_weight(decimal_graph, 1.1)
                | pynutil.add_weight(measure_graph, 1.1)
                | pynutil.add_weight(cardinal_graph, 1.1)
                | pynutil.add_weight(ordinal_graph, 1.1)
                | pynutil.add_weight(money_graph, 1.1)
                | pynutil.add_weight(telephone_graph, 1.1)
                | pynutil.add_weight(electronic_graph, 1.1)
                | pynutil.add_weight(fraction_graph, 1.1)
                | pynutil.add_weight(range_graph, 1.1)
                | pynutil.add_weight(serial_graph, 1.1001)  # should be higher than the rest of the classes
            )

            roman_graph = RomanFst(deterministic=deterministic).fst
            classify |= pynutil.add_weight(roman_graph, 1.1)

            if not deterministic:
                # Abbreviations are only offered as an option in non-deterministic mode.
                abbreviation_graph = AbbreviationFst(deterministic=deterministic).fst
                classify |= pynutil.add_weight(abbreviation_graph, 100)

            # One or more punctuation tokens, with whitespace runs between them collapsed.
            punct = pynutil.insert("tokens { ") + pynutil.add_weight(punct_graph, weight=2.1) + pynutil.insert(" }")
            punct = pynini.closure(
                pynini.compose(pynini.closure(NEMO_WHITE_SPACE, 1), delete_extra_space)
                | (pynutil.insert(" ") + punct),
                1,
            )

            # Plain words carry the largest weight, so any semiotic analysis is preferred when available.
            classify |= pynutil.add_weight(word_graph, 100)
            token = pynutil.insert("tokens { ") + classify + pynutil.insert(" }")
            token_plus_punct = (
                pynini.closure(punct + pynutil.insert(" ")) + token + pynini.closure(pynutil.insert(" ") + punct)
            )

            # A sentence: tokens separated by whitespace and/or punctuation.
            graph = token_plus_punct + pynini.closure(
                (
                    pynini.compose(pynini.closure(NEMO_WHITE_SPACE, 1), delete_extra_space)
                    | (pynutil.insert(" ") + punct + pynutil.insert(" "))
                )
                + token_plus_punct
            )

            graph = delete_space + graph + delete_space
            # Input consisting of punctuation only.
            graph |= punct

            self.fst = graph.optimize()

            if far_file:
                # Persist the compiled grammar for later runs.
                generator_main(far_file, {"tokenize_and_classify": self.fst})

View File

@@ -0,0 +1,228 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
NEMO_CHAR,
NEMO_DIGIT,
NEMO_NOT_SPACE,
NEMO_SIGMA,
NEMO_WHITE_SPACE,
GraphFst,
delete_extra_space,
delete_space,
generator_main,
)
from nemo_text_processing.text_normalization.en.taggers.cardinal import CardinalFst
from nemo_text_processing.text_normalization.en.taggers.date import DateFst
from nemo_text_processing.text_normalization.en.taggers.decimal import DecimalFst
from nemo_text_processing.text_normalization.en.taggers.electronic import ElectronicFst
from nemo_text_processing.text_normalization.en.taggers.fraction import FractionFst
from nemo_text_processing.text_normalization.en.taggers.measure import MeasureFst
from nemo_text_processing.text_normalization.en.taggers.money import MoneyFst
from nemo_text_processing.text_normalization.en.taggers.ordinal import OrdinalFst
from nemo_text_processing.text_normalization.en.taggers.punctuation import PunctuationFst
from nemo_text_processing.text_normalization.en.taggers.range import RangeFst as RangeFst
from nemo_text_processing.text_normalization.en.taggers.roman import RomanFst
from nemo_text_processing.text_normalization.en.taggers.serial import SerialFst
from nemo_text_processing.text_normalization.en.taggers.telephone import TelephoneFst
from nemo_text_processing.text_normalization.en.taggers.time import TimeFst
from nemo_text_processing.text_normalization.en.taggers.whitelist import WhiteListFst
from nemo_text_processing.text_normalization.en.taggers.word import WordFst
from nemo_text_processing.text_normalization.en.verbalizers.cardinal import CardinalFst as vCardinal
from nemo_text_processing.text_normalization.en.verbalizers.date import DateFst as vDate
from nemo_text_processing.text_normalization.en.verbalizers.decimal import DecimalFst as vDecimal
from nemo_text_processing.text_normalization.en.verbalizers.electronic import ElectronicFst as vElectronic
from nemo_text_processing.text_normalization.en.verbalizers.fraction import FractionFst as vFraction
from nemo_text_processing.text_normalization.en.verbalizers.measure import MeasureFst as vMeasure
from nemo_text_processing.text_normalization.en.verbalizers.money import MoneyFst as vMoney
from nemo_text_processing.text_normalization.en.verbalizers.ordinal import OrdinalFst as vOrdinal
from nemo_text_processing.text_normalization.en.verbalizers.roman import RomanFst as vRoman
from nemo_text_processing.text_normalization.en.verbalizers.telephone import TelephoneFst as vTelephone
from nemo_text_processing.text_normalization.en.verbalizers.time import TimeFst as vTime
from nemo_text_processing.text_normalization.en.verbalizers.word import WordFst as vWord
from pynini.examples import plurals
from pynini.lib import pynutil
from nemo.utils import logging
class ClassifyFst(GraphFst):
    """
    Sentence-level grammar that composes all classification grammars with their
    verbalizers. It can process an entire sentence including punctuation; every
    semiotic span is verbalized and wrapped in "< " ... " >" in the output.
    For deployment, this grammar will be compiled and exported to OpenFst Finite State Archive (FAR) File.
    More details to deployment at NeMo/tools/text_processing_deployment.

    Besides `self.fst`, the constructor sets `self.fst_no_digits`: `self.fst`
    restricted to outputs that contain no digit characters.

    Args:
        input_case: accepting either "lower_cased" or "cased" input.
        deterministic: if True will provide a single transduction option,
            for False multiple options (used for audio-based normalization)
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files. Defaults to True, so an
            existing .far file is only loaded when this is explicitly set to False.
        whitelist: path to a file with whitelist replacements
    """

    def __init__(
        self,
        input_case: str,
        deterministic: bool = True,
        cache_dir: str = None,
        overwrite_cache: bool = True,
        whitelist: str = None,
    ):
        super().__init__(name="tokenize_and_classify", kind="classify", deterministic=deterministic)

        far_file = None
        if cache_dir is not None and cache_dir != "None":
            os.makedirs(cache_dir, exist_ok=True)
            whitelist_file = os.path.basename(whitelist) if whitelist else ""
            far_name = f"_{input_case}_en_tn_{deterministic}_deterministic{whitelist_file}_lm.far"
            far_file = os.path.join(cache_dir, far_name)

        # Acceptor for strings without digits; composed onto the output side of self.fst below.
        digit_free = pynini.closure(pynini.difference(NEMO_CHAR, NEMO_DIGIT))

        if far_file and not overwrite_cache and os.path.exists(far_file):
            # Cached grammar available: load it and stop here.
            self.fst = pynini.Far(far_file, mode="r")["tokenize_and_classify"]
            self.fst_no_digits = pynini.compose(self.fst, digit_free).optimize()
            logging.info(f"ClassifyFst.fst was restored from {far_file}.")
            return

        logging.info("Creating ClassifyFst grammars. This might take some time...")

        # ------------------------------ TAGGERS ------------------------------
        cardinal_tagger = CardinalFst(deterministic=True, lm=True)
        cardinal_graph = cardinal_tagger.fst
        ordinal_tagger = OrdinalFst(cardinal=cardinal_tagger, deterministic=True)
        ordinal_graph = ordinal_tagger.fst
        decimal_tagger = DecimalFst(cardinal=cardinal_tagger, deterministic=True)
        decimal_graph = decimal_tagger.fst
        fraction_tagger = FractionFst(deterministic=True, cardinal=cardinal_tagger)
        fraction_graph = fraction_tagger.fst
        measure_graph = MeasureFst(
            cardinal=cardinal_tagger, decimal=decimal_tagger, fraction=fraction_tagger, deterministic=True
        ).fst
        date_graph = DateFst(cardinal=cardinal_tagger, deterministic=True, lm=True).fst
        punctuation = PunctuationFst(deterministic=True)
        punct_graph = punctuation.graph
        word_graph = WordFst(punctuation=punctuation, deterministic=deterministic).graph
        time_graph = TimeFst(cardinal=cardinal_tagger, deterministic=True).fst
        telephone_graph = TelephoneFst(deterministic=True).fst
        electronic_graph = ElectronicFst(deterministic=True).fst
        money_graph = MoneyFst(cardinal=cardinal_tagger, decimal=decimal_tagger, deterministic=False).fst
        whitelist_graph = WhiteListFst(input_case=input_case, deterministic=False, input_file=whitelist).graph
        serial_graph = SerialFst(
            cardinal=cardinal_tagger, ordinal=ordinal_tagger, deterministic=deterministic, lm=True
        ).fst

        # ---------------------------- VERBALIZERS ----------------------------
        cardinal_verbalizer = vCardinal(deterministic=True)
        v_cardinal_graph = cardinal_verbalizer.fst
        decimal_verbalizer = vDecimal(cardinal=cardinal_verbalizer, deterministic=True)
        v_decimal_graph = decimal_verbalizer.fst
        ordinal_verbalizer = vOrdinal(deterministic=True)
        v_ordinal_graph = ordinal_verbalizer.fst
        fraction_verbalizer = vFraction(deterministic=True, lm=True)
        v_fraction_graph = fraction_verbalizer.fst
        v_telephone_graph = vTelephone(deterministic=True).fst
        v_electronic_graph = vElectronic(deterministic=True).fst
        v_measure_graph = vMeasure(
            decimal=decimal_verbalizer,
            cardinal=cardinal_verbalizer,
            fraction=fraction_verbalizer,
            deterministic=False,
        ).fst
        v_time_graph = vTime(deterministic=True).fst
        v_date_graph = vDate(ordinal=ordinal_verbalizer, deterministic=deterministic, lm=True).fst
        v_money_graph = vMoney(decimal=decimal_verbalizer, deterministic=deterministic).fst
        v_roman_graph = vRoman(deterministic=deterministic).fst
        v_word_graph = vWord(deterministic=deterministic).fst

        # ---------------------- tagger @ verbalizer pairs ---------------------
        # The date reading takes priority over the cardinal reading of the same input.
        cardinal_or_date_final = plurals._priority_union(date_graph, cardinal_graph, NEMO_SIGMA)
        cardinal_or_date_final = pynini.compose(cardinal_or_date_final, (v_cardinal_graph | v_date_graph))
        time_final = pynini.compose(time_graph, v_time_graph)
        ordinal_final = pynini.compose(ordinal_graph, v_ordinal_graph)

        sem_w = 1
        word_w = 100
        punct_w = 2

        weighted_branches = [
            (time_final, sem_w),
            (pynini.compose(decimal_graph, v_decimal_graph), sem_w),
            (pynini.compose(measure_graph, v_measure_graph), sem_w),
            (ordinal_final, sem_w),
            (pynini.compose(telephone_graph, v_telephone_graph), sem_w),
            (pynini.compose(electronic_graph, v_electronic_graph), sem_w),
            (pynini.compose(fraction_graph, v_fraction_graph), sem_w),
            (pynini.compose(money_graph, v_money_graph), sem_w),
            (cardinal_or_date_final, sem_w),
            (whitelist_graph, sem_w),
            # should be higher than the rest of the classes
            (pynini.compose(serial_graph, v_word_graph), 1.1001),
        ]
        # Left-to-right union of the weighted branches.
        semiotic = None
        for branch, weight in weighted_branches:
            weighted = pynutil.add_weight(branch, weight)
            semiotic = weighted if semiotic is None else semiotic | weighted
        semiotic = semiotic.optimize()

        roman_graph = RomanFst(deterministic=deterministic, lm=True).fst
        # the weight matches the word_graph weight for "I" cases in long sentences with multiple semiotic tokens
        semiotic |= pynutil.add_weight(pynini.compose(roman_graph, v_roman_graph), sem_w)

        date_final = pynini.compose(date_graph, v_date_graph)
        range_graph = RangeFst(
            time=time_final, cardinal=cardinal_tagger, date=date_final, deterministic=deterministic
        ).fst
        semiotic |= pynutil.add_weight(pynini.compose(range_graph, v_word_graph), sem_w)

        # Mark semiotic spans in the output; plain words are added unmarked with a large weight.
        token = pynutil.insert("< ") + semiotic + pynutil.insert(" >")
        token |= pynutil.add_weight(word_graph, word_w)

        # --------------------------- sentence graph ---------------------------
        collapsed_whitespace = pynini.compose(pynini.closure(NEMO_WHITE_SPACE, 1), delete_extra_space)
        punct_only = pynutil.add_weight(punct_graph, weight=punct_w)
        punct = pynini.closure(collapsed_whitespace | (pynutil.insert(" ") + punct_only), 1)

        token_plus_punct = (
            pynini.closure(punct + pynutil.insert(" ")) + token + pynini.closure(pynutil.insert(" ") + punct)
        )
        token_separator = collapsed_whitespace | (pynutil.insert(" ") + punct + pynutil.insert(" "))
        sentence = token_plus_punct + pynini.closure(token_separator + token_plus_punct)
        # Input that starts with punctuation and contains no tokens.
        sentence |= punct_only + pynini.closure(punct)
        sentence = delete_space + sentence + delete_space

        # Output-side clean-up: drop leading spaces and pass runs of spaces
        # between non-space chunks through delete_extra_space.
        chunk = pynini.closure(NEMO_NOT_SPACE, 1)
        remove_extra_spaces = chunk + pynini.closure(delete_extra_space + pynini.closure(NEMO_NOT_SPACE, 1))
        remove_extra_spaces |= (
            pynini.closure(pynutil.delete(" "), 1)
            + pynini.closure(NEMO_NOT_SPACE, 1)
            + pynini.closure(delete_extra_space + pynini.closure(NEMO_NOT_SPACE, 1))
        )

        self.fst = pynini.compose(sentence.optimize(), remove_extra_spaces).optimize()
        self.fst_no_digits = pynini.compose(self.fst, digit_free).optimize()

        if far_file:
            generator_main(far_file, {"tokenize_and_classify": self.fst})
            logging.info(f"ClassifyFst grammars are saved to {far_file}.")

View File

@@ -0,0 +1,229 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
NEMO_CHAR,
NEMO_DIGIT,
NEMO_NOT_SPACE,
NEMO_WHITE_SPACE,
GraphFst,
delete_extra_space,
delete_space,
generator_main,
)
from nemo_text_processing.text_normalization.en.taggers.abbreviation import AbbreviationFst
from nemo_text_processing.text_normalization.en.taggers.cardinal import CardinalFst
from nemo_text_processing.text_normalization.en.taggers.date import DateFst
from nemo_text_processing.text_normalization.en.taggers.decimal import DecimalFst
from nemo_text_processing.text_normalization.en.taggers.electronic import ElectronicFst
from nemo_text_processing.text_normalization.en.taggers.fraction import FractionFst
from nemo_text_processing.text_normalization.en.taggers.measure import MeasureFst
from nemo_text_processing.text_normalization.en.taggers.money import MoneyFst
from nemo_text_processing.text_normalization.en.taggers.ordinal import OrdinalFst
from nemo_text_processing.text_normalization.en.taggers.punctuation import PunctuationFst
from nemo_text_processing.text_normalization.en.taggers.range import RangeFst as RangeFst
from nemo_text_processing.text_normalization.en.taggers.roman import RomanFst
from nemo_text_processing.text_normalization.en.taggers.serial import SerialFst
from nemo_text_processing.text_normalization.en.taggers.telephone import TelephoneFst
from nemo_text_processing.text_normalization.en.taggers.time import TimeFst
from nemo_text_processing.text_normalization.en.taggers.whitelist import WhiteListFst
from nemo_text_processing.text_normalization.en.taggers.word import WordFst
from nemo_text_processing.text_normalization.en.verbalizers.abbreviation import AbbreviationFst as vAbbreviation
from nemo_text_processing.text_normalization.en.verbalizers.cardinal import CardinalFst as vCardinal
from nemo_text_processing.text_normalization.en.verbalizers.date import DateFst as vDate
from nemo_text_processing.text_normalization.en.verbalizers.decimal import DecimalFst as vDecimal
from nemo_text_processing.text_normalization.en.verbalizers.electronic import ElectronicFst as vElectronic
from nemo_text_processing.text_normalization.en.verbalizers.fraction import FractionFst as vFraction
from nemo_text_processing.text_normalization.en.verbalizers.measure import MeasureFst as vMeasure
from nemo_text_processing.text_normalization.en.verbalizers.money import MoneyFst as vMoney
from nemo_text_processing.text_normalization.en.verbalizers.ordinal import OrdinalFst as vOrdinal
from nemo_text_processing.text_normalization.en.verbalizers.roman import RomanFst as vRoman
from nemo_text_processing.text_normalization.en.verbalizers.telephone import TelephoneFst as vTelephone
from nemo_text_processing.text_normalization.en.verbalizers.time import TimeFst as vTime
from nemo_text_processing.text_normalization.en.verbalizers.word import WordFst as vWord
from pynini.lib import pynutil
from nemo.utils import logging
class ClassifyFst(GraphFst):
    """
    Final class that composes all other classification grammars. This class can process an entire sentence including punctuation.
    For deployment, this grammar will be compiled and exported to OpenFst Finite State Archive (FAR) File.
    More details to deployment at NeMo/tools/text_processing_deployment.

    Every semiotic tagger built here is composed directly with its verbalizer, so ``self.fst`` is the
    sentence-level union of "tag then verbalize" graphs rather than a pure tagger.
    ``self.fst_no_digits`` is ``self.fst`` composed with a filter that accepts only strings made of
    non-digit characters.

    Args:
        input_case: accepting either "lower_cased" or "cased" input.
        deterministic: if True will provide a single transduction option,
            for False multiple options (used for audio-based normalization)
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
        whitelist: path to a file with whitelist replacements
    """

    def __init__(
        self,
        input_case: str,
        deterministic: bool = True,
        cache_dir: str = None,
        overwrite_cache: bool = True,
        whitelist: str = None,
    ):
        super().__init__(name="tokenize_and_classify", kind="classify", deterministic=deterministic)

        # Resolve the cache file. The literal string 'None' is treated like a missing cache_dir.
        far_file = None
        if cache_dir is not None and cache_dir != 'None':
            os.makedirs(cache_dir, exist_ok=True)
            # Only the whitelist's base name is part of the cache key.
            whitelist_file = os.path.basename(whitelist) if whitelist else ""
            far_file = os.path.join(
                cache_dir, f"_{input_case}_en_tn_{deterministic}_deterministic{whitelist_file}.far"
            )
        if not overwrite_cache and far_file and os.path.exists(far_file):
            # Fast path: reuse the compiled grammar and only rebuild the digit-free variant.
            self.fst = pynini.Far(far_file, mode='r')['tokenize_and_classify']
            no_digits = pynini.closure(pynini.difference(NEMO_CHAR, NEMO_DIGIT))
            self.fst_no_digits = pynini.compose(self.fst, no_digits).optimize()
            logging.info(f'ClassifyFst.fst was restored from {far_file}.')
        else:
            logging.info(f'Creating ClassifyFst grammars. This might take some time...')
            # TAGGERS
            cardinal = CardinalFst(deterministic=deterministic)
            cardinal_graph = cardinal.fst
            ordinal = OrdinalFst(cardinal=cardinal, deterministic=deterministic)
            # Always-deterministic ordinal tagger, used only by SerialFst below.
            deterministic_ordinal = OrdinalFst(cardinal=cardinal, deterministic=True)
            ordinal_graph = ordinal.fst
            decimal = DecimalFst(cardinal=cardinal, deterministic=deterministic)
            decimal_graph = decimal.fst
            fraction = FractionFst(deterministic=deterministic, cardinal=cardinal)
            fraction_graph = fraction.fst
            measure = MeasureFst(cardinal=cardinal, decimal=decimal, fraction=fraction, deterministic=deterministic)
            measure_graph = measure.fst
            date_graph = DateFst(cardinal=cardinal, deterministic=deterministic).fst
            punctuation = PunctuationFst(deterministic=True)
            punct_graph = punctuation.graph
            word_graph = WordFst(punctuation=punctuation, deterministic=deterministic).graph
            time_graph = TimeFst(cardinal=cardinal, deterministic=deterministic).fst
            telephone_graph = TelephoneFst(deterministic=deterministic).fst
            electronic_graph = ElectronicFst(deterministic=deterministic).fst
            money_graph = MoneyFst(cardinal=cardinal, decimal=decimal, deterministic=deterministic).fst
            # NOTE: from here on `whitelist` is the WhiteListFst instance, no longer the file path argument.
            whitelist = WhiteListFst(input_case=input_case, deterministic=deterministic, input_file=whitelist)
            whitelist_graph = whitelist.graph
            serial_graph = SerialFst(cardinal=cardinal, ordinal=deterministic_ordinal, deterministic=deterministic).fst
            # VERBALIZERS
            # NOTE: `cardinal`, `decimal`, `ordinal`, `fraction` and `measure` are rebound to verbalizer
            # instances below; the tagger graphs were already captured in the *_graph variables above.
            cardinal = vCardinal(deterministic=deterministic)
            v_cardinal_graph = cardinal.fst
            decimal = vDecimal(cardinal=cardinal, deterministic=deterministic)
            v_decimal_graph = decimal.fst
            ordinal = vOrdinal(deterministic=deterministic)
            v_ordinal_graph = ordinal.fst
            fraction = vFraction(deterministic=deterministic)
            v_fraction_graph = fraction.fst
            v_telephone_graph = vTelephone(deterministic=deterministic).fst
            v_electronic_graph = vElectronic(deterministic=deterministic).fst
            measure = vMeasure(decimal=decimal, cardinal=cardinal, fraction=fraction, deterministic=deterministic)
            v_measure_graph = measure.fst
            v_time_graph = vTime(deterministic=deterministic).fst
            v_date_graph = vDate(ordinal=ordinal, deterministic=deterministic).fst
            v_money_graph = vMoney(decimal=decimal, deterministic=deterministic).fst
            v_roman_graph = vRoman(deterministic=deterministic).fst
            v_abbreviation = vAbbreviation(deterministic=deterministic).fst
            # Ranges are built from time/date graphs verbalized with deterministic verbalizers and a
            # deterministic cardinal tagger, regardless of the `deterministic` flag of this class.
            det_v_time_graph = vTime(deterministic=True).fst
            det_v_date_graph = vDate(ordinal=vOrdinal(deterministic=True), deterministic=True).fst
            time_final = pynini.compose(time_graph, det_v_time_graph)
            date_final = pynini.compose(date_graph, det_v_date_graph)
            range_graph = RangeFst(
                time=time_final, date=date_final, cardinal=CardinalFst(deterministic=True), deterministic=deterministic
            ).fst
            v_word_graph = vWord(deterministic=deterministic).fst
            # Branch weights: semiotic classes share sem_w, plain words get word_w, punctuation punct_w.
            # NOTE(review): presumably the lowest-weight path is preferred at decoding time - confirm.
            sem_w = 1
            word_w = 100
            punct_w = 2
            # Union of every "tagger composed with verbalizer" graph; whitelist and word graphs are
            # used as-is (their `.graph` attributes, not `.fst`).
            classify_and_verbalize = (
                pynutil.add_weight(whitelist_graph, sem_w)
                | pynutil.add_weight(pynini.compose(time_graph, v_time_graph), sem_w)
                | pynutil.add_weight(pynini.compose(decimal_graph, v_decimal_graph), sem_w)
                | pynutil.add_weight(pynini.compose(measure_graph, v_measure_graph), sem_w)
                | pynutil.add_weight(pynini.compose(cardinal_graph, v_cardinal_graph), sem_w)
                | pynutil.add_weight(pynini.compose(ordinal_graph, v_ordinal_graph), sem_w)
                | pynutil.add_weight(pynini.compose(telephone_graph, v_telephone_graph), sem_w)
                | pynutil.add_weight(pynini.compose(electronic_graph, v_electronic_graph), sem_w)
                | pynutil.add_weight(pynini.compose(fraction_graph, v_fraction_graph), sem_w)
                | pynutil.add_weight(pynini.compose(money_graph, v_money_graph), sem_w)
                | pynutil.add_weight(word_graph, word_w)
                | pynutil.add_weight(pynini.compose(date_graph, v_date_graph), sem_w - 0.01)
                | pynutil.add_weight(pynini.compose(range_graph, v_word_graph), sem_w)
                | pynutil.add_weight(
                    pynini.compose(serial_graph, v_word_graph), 1.1001
                )  # should be higher than the rest of the classes
            ).optimize()
            if not deterministic:
                # Roman numerals and abbreviations are only offered as extra options in non-deterministic mode.
                roman_graph = RomanFst(deterministic=deterministic).fst
                # the weight matches the word_graph weight for "I" cases in long sentences with multiple semiotic tokens
                classify_and_verbalize |= pynutil.add_weight(pynini.compose(roman_graph, v_roman_graph), word_w)
                abbreviation_graph = AbbreviationFst(whitelist=whitelist, deterministic=deterministic).fst
                classify_and_verbalize |= pynutil.add_weight(
                    pynini.compose(abbreviation_graph, v_abbreviation), word_w
                )
            punct_only = pynutil.add_weight(punct_graph, weight=punct_w)
            # One or more punctuation items: either a whitespace run passed through delete_extra_space,
            # or a punctuation mark preceded by an inserted space.
            punct = pynini.closure(
                pynini.compose(pynini.closure(NEMO_WHITE_SPACE, 1), delete_extra_space)
                | (pynutil.insert(" ") + punct_only),
                1,
            )
            # A token optionally surrounded by punctuation, with spaces inserted between the pieces.
            token_plus_punct = (
                pynini.closure(punct + pynutil.insert(" "))
                + classify_and_verbalize
                + pynini.closure(pynutil.insert(" ") + punct)
            )
            # Sentence: tokens separated by whitespace or by punctuation.
            graph = token_plus_punct + pynini.closure(
                (
                    pynini.compose(pynini.closure(NEMO_WHITE_SPACE, 1), delete_extra_space)
                    | (pynutil.insert(" ") + punct + pynutil.insert(" "))
                )
                + token_plus_punct
            )
            # Also accept input that consists of punctuation only.
            graph |= punct_only + pynini.closure(punct)
            graph = delete_space + graph + delete_space
            # Output-side clean-up: non-space runs joined through delete_extra_space, optionally after
            # deleting one or more leading spaces.
            # NOTE(review): delete_extra_space presumably collapses a space run to a single space - confirm.
            remove_extra_spaces = pynini.closure(NEMO_NOT_SPACE, 1) + pynini.closure(
                delete_extra_space + pynini.closure(NEMO_NOT_SPACE, 1)
            )
            remove_extra_spaces |= (
                pynini.closure(pynutil.delete(" "), 1)
                + pynini.closure(NEMO_NOT_SPACE, 1)
                + pynini.closure(delete_extra_space + pynini.closure(NEMO_NOT_SPACE, 1))
            )
            graph = pynini.compose(graph.optimize(), remove_extra_spaces).optimize()
            self.fst = graph
            # Variant whose output side is restricted to strings without digit characters.
            no_digits = pynini.closure(pynini.difference(NEMO_CHAR, NEMO_DIGIT))
            self.fst_no_digits = pynini.compose(graph, no_digits).optimize()
            if far_file:
                # Only self.fst is cached; fst_no_digits is recomputed when the cache is restored.
                generator_main(far_file, {"tokenize_and_classify": self.fst})
                logging.info(f'ClassifyFst grammars are saved to {far_file}.')

View File

@@ -0,0 +1,151 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
NEMO_CHAR,
NEMO_NOT_SPACE,
NEMO_SIGMA,
NEMO_UPPER,
SINGULAR_TO_PLURAL,
GraphFst,
convert_space,
)
from nemo_text_processing.text_normalization.en.taggers.roman import get_names
from nemo_text_processing.text_normalization.en.utils import (
augment_labels_with_punct_at_end,
get_abs_path,
load_labels,
)
from pynini.lib import pynutil
class WhiteListFst(GraphFst):
    """
    Finite state transducer for classifying whitelist, e.g.
        misses -> tokens { name: "mrs" }
        for non-deterministic case: "Dr. Abc" ->
            tokens { name: "drive" } tokens { name: "Abc" }
            tokens { name: "doctor" } tokens { name: "Abc" }
            tokens { name: "Dr." } tokens { name: "Abc" }

    This class has highest priority among all classifier grammars. The built-in replacements are loaded
    from "data/whitelist/tts.tsv", "data/whitelist/UK_to_US.tsv" and "data/whitelist/symbol.tsv";
    non-deterministic mode additionally loads the "alternatives" files and the measure units.

    Args:
        input_case: accepting either "lower_cased" or "cased" input.
        deterministic: if True will provide a single transduction option,
            for False multiple options (used for audio-based normalization)
        input_file: path to a file with whitelist replacements. In deterministic mode it replaces the
            built-in graph entirely; otherwise it is added to it.
    """

    def __init__(self, input_case: str, deterministic: bool = True, input_file: str = None):
        super().__init__(name="whitelist", kind="classify", deterministic=deterministic)

        def _build_whitelist(case, path, keep_punct_add_end: bool = False):
            # Load two-column rows, lower-casing the written form when requested.
            pairs = []
            for written, spoken in load_labels(path):
                pairs.append([written.lower() if case == "lower_cased" else written, spoken])
            if keep_punct_add_end:
                pairs.extend(augment_labels_with_punct_at_end(pairs))
            return pynini.string_map(pairs)

        tts_graph = _build_whitelist(input_case, get_abs_path("data/whitelist/tts.tsv"))
        uk_to_us_graph = _build_whitelist(input_case, get_abs_path("data/whitelist/UK_to_US.tsv"))  # Jiayu 2022.10
        # The symbol table is applied to every input except the bare "/".
        symbol_graph = pynini.compose(
            pynini.difference(NEMO_SIGMA, pynini.accep("/")).optimize(),
            _build_whitelist(input_case, get_abs_path("data/whitelist/symbol.tsv")),
        ).optimize()
        graph = tts_graph | uk_to_us_graph | symbol_graph

        if not deterministic:
            graph |= _build_whitelist(
                input_case, get_abs_path("data/whitelist/alternatives.tsv"), keep_punct_add_end=True
            )
        else:
            # "St. <name>" -> "Saint <name>"
            saint_prefix = pynini.cross(pynini.union("st", "St", "ST"), "Saint")
            graph |= saint_prefix + pynini.closure(pynutil.delete(".")) + pynini.accep(" ") + get_names()

        # Upper-case letters (at least three) separated by "." or ". ": the separators are dropped,
        # together with an optional trailing period.
        for separator in (".", ". "):
            following_letters = pynini.closure(pynutil.delete(separator) + NEMO_UPPER, 2)
            graph |= NEMO_UPPER + following_letters + pynini.closure(pynutil.delete("."), 0, 1)

        if not deterministic:
            graph |= get_formats(get_abs_path("data/whitelist/alternatives_all_format.tsv"))

            singular_units = pynini.string_file(get_abs_path("data/measure/unit.tsv")) | pynini.string_file(
                get_abs_path("data/measure/unit_alternatives.tsv")
            )
            plural_units = singular_units @ SINGULAR_TO_PLURAL
            # Only inputs of three or more characters are matched against the unit tables.
            graph |= pynini.compose(NEMO_CHAR ** (3, ...), convert_space(singular_units | plural_units))

        # convert to states only if comma is present before the abbreviation to avoid converting all caps words,
        # e.g. "IN", "OH", "OK"
        # TODO or only exclude above?
        states = load_labels(get_abs_path("data/address/state.tsv"))
        dotted_variants = []
        for expansion, abbreviation in states:
            if input_case == "lower_cased":
                expansion = expansion.lower()
            dotted = f"{abbreviation[0]}.{abbreviation[1:]}"
            dotted_variants.append((expansion, dotted))
            if not deterministic:
                dotted_variants.append((expansion, f"{dotted}."))
        state_graph = pynini.string_map(states + dotted_variants)
        preceding_word = pynini.closure(NEMO_NOT_SPACE, 1)
        graph |= preceding_word + pynini.union(", ", ",") + pynini.invert(state_graph).optimize()

        if input_file:
            user_graph = _build_whitelist(input_case, input_file)
            # A user whitelist overrides everything in deterministic mode and is additive otherwise.
            graph = user_graph if deterministic else graph | user_graph

        self.graph = (convert_space(graph)).optimize()
        self.fst = (pynutil.insert("name: \"") + self.graph + pynutil.insert("\"")).optimize()
def get_formats(input_f, input_case="cased", is_default=True):
    """
    Adds various abbreviation format options to the list of acceptable input forms.

    For every (written, spoken) row of the file three extra rows are generated: the written form with
    a trailing period, the capitalized written form mapped to the capitalized spoken form, and the
    capitalized written form with a trailing period.

    Args:
        input_f: path to a two-column tab-separated file
        input_case: when "lower_cased", the written form is lower-cased before the variants are derived
        is_default: when False, every output is wrapped in |raw_start|...|raw_end||norm_start|...|norm_end| markers

    Returns:
        the graph built with pynini.string_map over the original and the generated rows
    """
    rows = load_labels(input_f)
    variants = []
    for written, spoken in rows:
        if input_case == "lower_cased":
            written = written.lower()
        capital_written = written[0].upper() + written[1:]
        capital_spoken = spoken[0].upper() + spoken[1:]
        variants += [
            (f"{written}.", spoken),  # default "dr" -> doctor, this includes period "dr." -> doctor
            (f"{capital_written}", f"{capital_spoken}"),  # "Dr" -> Doctor
            (f"{capital_written}.", f"{capital_spoken}"),  # "Dr." -> Doctor
        ]
    rows.extend(variants)

    if not is_default:
        rows = [
            (written, f"|raw_start|{written}|raw_end||norm_start|{spoken}|norm_end|") for (written, spoken) in rows
        ]

    return pynini.string_map(rows)

View File

@@ -0,0 +1,90 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
MIN_NEG_WEIGHT,
NEMO_ALPHA,
NEMO_DIGIT,
NEMO_NOT_SPACE,
NEMO_SIGMA,
GraphFst,
convert_space,
get_abs_path,
)
from nemo_text_processing.text_normalization.en.taggers.punctuation import PunctuationFst
from pynini.examples import plurals
from pynini.lib import pynutil
class WordFst(GraphFst):
    """
    Finite state transducer for classifying word. Considers sentence boundary exceptions.
        e.g. sleep -> tokens { name: "sleep" }

    Bracketed phoneme sequences ([HH AH0 L OW1]) and bracketed IPA strings ([ˈdoʊv]) take priority over
    the generic word graph and are passed through untouched.

    Args:
        punctuation: PunctuationFst; its ``punct_marks`` are used as delimiters inside IPA sentences
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, punctuation: GraphFst, deterministic: bool = True):
        super().__init__(name="word", kind="classify", deterministic=deterministic)

        # Fallback: any run of non-space characters that are not punctuation inputs.
        punct = PunctuationFst().graph
        default_graph = pynini.closure(pynini.difference(NEMO_NOT_SPACE, punct.project("input")), 1)

        # Preferred path: runs without digits or currency/number symbols. The euro and won signs had
        # degraded to empty strings, which silently dropped them from the exclusion set.
        symbols_to_exclude = (pynini.union("$", "€", "₩", "£", "¥", "#", "%") | NEMO_DIGIT).optimize()
        graph = pynini.closure(pynini.difference(NEMO_NOT_SPACE, symbols_to_exclude), 1)
        graph = pynutil.add_weight(graph, MIN_NEG_WEIGHT) | default_graph

        # leave phones of format [HH AH0 L OW1] untouched
        phoneme_unit = pynini.closure(NEMO_ALPHA, 1) + pynini.closure(NEMO_DIGIT)
        phoneme = (
            pynini.accep(pynini.escape("["))
            + pynini.closure(phoneme_unit + pynini.accep(" "))
            + phoneme_unit
            + pynini.accep(pynini.escape("]"))
        )

        # leave IPA phones of format [ˈdoʊv] untouched, single words and sentences with punctuation marks allowed
        punct_marks = pynini.union(*punctuation.punct_marks).optimize()
        stress = pynini.union("ˈ", "'", "ˌ")
        ipa_phoneme_unit = pynini.string_file(get_abs_path("data/whitelist/ipa_symbols.tsv"))
        # word in ipa form
        ipa_phonemes = (
            pynini.closure(stress, 0, 1)
            + pynini.closure(ipa_phoneme_unit, 1)
            + pynini.closure(stress | ipa_phoneme_unit)
        )
        # allow sentences of words in IPA format separated with spaces or punct marks
        delim = (punct_marks | pynini.accep(" ")) ** (1, ...)
        ipa_phonemes = ipa_phonemes + pynini.closure(delim + ipa_phonemes) + pynini.closure(delim, 0, 1)
        ipa_phonemes = (pynini.accep(pynini.escape("[")) + ipa_phonemes + pynini.accep(pynini.escape("]"))).optimize()

        if not deterministic:
            # Non-deterministic mode also tolerates one space right after "[" and right before "]".
            phoneme = (
                pynini.accep(pynini.escape("["))
                + pynini.closure(pynini.accep(" "), 0, 1)
                + pynini.closure(phoneme_unit + pynini.accep(" "))
                + phoneme_unit
                + pynini.closure(pynini.accep(" "), 0, 1)
                + pynini.accep(pynini.escape("]"))
            ).optimize()
            # NOTE(review): ipa_phonemes is already bracketed above, so this wraps it a second time
            # ("[[...]]") in non-deterministic mode. Kept as-is; confirm whether this is intended.
            ipa_phonemes = (
                pynini.accep(pynini.escape("[")) + ipa_phonemes + pynini.accep(pynini.escape("]"))
            ).optimize()

        phoneme |= ipa_phonemes
        # Phoneme/IPA strings win over the generic word graph wherever both accept the input.
        self.graph = plurals._priority_union(convert_space(phoneme.optimize()), graph, NEMO_SIGMA)
        self.fst = (pynutil.insert("name: \"") + self.graph + pynutil.insert("\"")).optimize()

View File

@@ -0,0 +1,60 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import os
def get_abs_path(rel_path):
    """
    Resolve a path relative to the directory that contains this module.

    Args:
        rel_path: path relative to this file's directory

    Returns:
        the module directory and rel_path joined with a forward slash
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return "/".join((module_dir, rel_path))
def load_labels(abs_path):
    """
    Load a tab-separated file as a list of rows.

    Args:
        abs_path: absolute path to a UTF-8 encoded .tsv file

    Returns:
        list with one entry per line, each entry being the list of that line's column values
    """
    # The context manager closes the handle; it was previously left open.
    with open(abs_path, encoding="utf-8") as label_tsv:
        return list(csv.reader(label_tsv, delimiter="\t"))
def augment_labels_with_punct_at_end(labels):
    """
    augments labels: if key ends on a punctuation that value does not have, add a new label
    where the value maintains the punctuation

    Only the period is considered. Rows with fewer than two columns are skipped, and any columns
    after the second are copied unchanged into the new label.

    Args:
        labels : input labels, a list of rows (each a list of strings)
    Returns:
        additional labels; the input list is not modified
    """
    # endswith() instead of [-1] so an empty key or value no longer raises IndexError.
    return [
        [label[0], label[1] + "."] + label[2:]
        for label in labels
        if len(label) > 1 and label[0].endswith(".") and not label[1].endswith(".")
    ]

View File

@@ -0,0 +1,13 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,35 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, GraphFst
from pynini.lib import pynutil
class AbbreviationFst(GraphFst):
    """
    Finite state transducer for verbalizing abbreviations.

    Strips the ``value: "..."`` wrapper of an abbreviation token and emits the quoted content
    (one or more non-quote characters) unchanged,
        e.g. tokens { abbreviation { value: "A B C" } } -> A B C

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="abbreviation", kind="verbalize", deterministic=deterministic)

        quoted_value = pynini.closure(NEMO_NOT_QUOTE, 1)
        value_field = pynutil.delete("value: \"") + quoted_value + pynutil.delete("\"")
        self.fst = self.delete_tokens(value_field).optimize()

View File

@@ -0,0 +1,45 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, GraphFst, delete_space
from pynini.lib import pynutil
class CardinalFst(GraphFst):
    """
    Finite state transducer for verbalizing cardinal, e.g.
        cardinal { negative: "true" integer: "23" } -> minus twenty three

    Exposes ``optional_sign``, ``integer`` (the quoted value without its field name) and ``numbers``
    for reuse by other verbalizers.

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple options (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="cardinal", kind="verbalize", deterministic=deterministic)

        # negative: "true" -> "minus " (also "negative " when non-deterministic)
        sign = pynini.cross("negative: \"true\"", "minus ")
        if not deterministic:
            sign |= pynini.cross("negative: \"true\"", "negative ")
        self.optional_sign = pynini.closure(sign + delete_space, 0, 1)

        # Quoted value with the quotes removed; the content may be empty.
        quoted_content = pynini.closure(NEMO_NOT_QUOTE)
        self.integer = delete_space + pynutil.delete("\"") + quoted_content + pynutil.delete("\"")

        integer_field = pynutil.delete("integer:") + self.integer
        self.numbers = self.optional_sign + integer_field

        self.fst = self.delete_tokens(self.numbers).optimize()

View File

@@ -0,0 +1,101 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
NEMO_NOT_QUOTE,
NEMO_SIGMA,
GraphFst,
delete_extra_space,
delete_space,
)
from pynini.examples import plurals
from pynini.lib import pynutil
class DateFst(GraphFst):
    """
    Finite state transducer for verbalizing date, e.g.
        date { month: "february" day: "five" year: "twenty twelve" preserve_order: true } -> february fifth twenty twelve
        date { day: "five" month: "february" year: "twenty twelve" preserve_order: true } -> the fifth of february twenty twelve

    Args:
        ordinal: OrdinalFst, its ``suffix`` graph turns the day into an ordinal
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
        lm: when True, suppresses the extra cardinal-day option of non-deterministic mode
    """

    def __init__(self, ordinal: GraphFst, deterministic: bool = True, lm: bool = False):
        super().__init__(name="date", kind="verbalize", deterministic=deterministic)

        drop_quote = pynutil.delete("\"")
        quoted_text = pynini.closure(NEMO_NOT_QUOTE, 1)

        # Individual fields with their names and quotes removed.
        day_cardinal = pynutil.delete("day:") + delete_space + drop_quote + quoted_text + drop_quote
        day = day_cardinal @ ordinal.suffix
        month = pynutil.delete("month:") + delete_space + drop_quote + quoted_text + drop_quote
        year = pynutil.delete("year:") + delete_space + drop_quote + quoted_text + delete_space + drop_quote

        optional_year = pynini.closure(delete_extra_space + year, 0, 1)

        # month (day) year
        graph_mdy = month + pynini.closure(delete_extra_space + day, 0, 1) + optional_year

        # may 5 -> may five
        if not (deterministic or lm):
            graph_mdy |= month + pynini.closure(delete_extra_space + day_cardinal, 0, 1) + optional_year

        # day month year
        graph_dmy = (
            pynutil.insert("the ") + day + delete_extra_space + pynutil.insert("of ") + month + optional_year
        )

        # Trailing bookkeeping fields are removed from the output.
        preserve_order = pynutil.delete("preserve_order:") + delete_space + pynutil.delete("true") + delete_space
        field_order = (
            pynutil.delete("field_order:") + delete_space + drop_quote + NEMO_NOT_QUOTE + drop_quote + delete_space
        )
        optional_preserve_order = pynini.closure(preserve_order | field_order)

        # month-first wins over day-first; a bare year is accepted as well.
        ordered_date = plurals._priority_union(graph_mdy, pynutil.add_weight(graph_dmy, 0.0001), NEMO_SIGMA)
        final_graph = (ordered_date | year) + delete_space + optional_preserve_order

        self.fst = self.delete_tokens(final_graph).optimize()

View File

@@ -0,0 +1,67 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, GraphFst, delete_space, insert_space
from pynini.lib import pynutil
class DecimalFst(GraphFst):
    """
    Finite state transducer for verbalizing decimal, e.g.
        decimal { negative: "true" integer_part: "twelve" fractional_part: "five o o six" quantity: "billion" } -> minus twelve point five o o six billion

    Args:
        cardinal: cardinal verbalizer, its ``integer`` graph is reused for the integer part
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, cardinal, deterministic: bool = True):
        super().__init__(name="decimal", kind="verbalize", deterministic=deterministic)

        drop_quote = pynutil.delete("\"")
        quoted_text = pynini.closure(NEMO_NOT_QUOTE, 1)

        # negative: "true" -> "minus " (also "negative " when non-deterministic)
        sign = pynini.cross("negative: \"true\"", "minus ")
        if not deterministic:
            sign |= pynini.cross("negative: \"true\"", "negative ")
        self.optional_sign = pynini.closure(sign + delete_space, 0, 1)

        self.integer = pynutil.delete("integer_part:") + cardinal.integer
        self.optional_integer = pynini.closure(self.integer + delete_space + insert_space, 0, 1)

        self.fractional_default = (
            pynutil.delete("fractional_part:") + delete_space + drop_quote + quoted_text + drop_quote
        )
        self.fractional = pynutil.insert("point ") + self.fractional_default

        self.quantity = (
            delete_space + insert_space + pynutil.delete("quantity:") + delete_space + drop_quote + quoted_text + drop_quote
        )
        self.optional_quantity = pynini.closure(self.quantity, 0, 1)

        # integer | integer + quantity | (integer) + fraction + (quantity)
        unsigned = (
            self.integer
            | (self.integer + self.quantity)
            | (self.optional_integer + self.fractional + self.optional_quantity)
        )
        self.numbers = self.optional_sign + unsigned

        self.fst = self.delete_tokens(self.numbers).optimize()

View File

@@ -0,0 +1,97 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
NEMO_NOT_QUOTE,
NEMO_NOT_SPACE,
NEMO_SIGMA,
TO_UPPER,
GraphFst,
delete_extra_space,
delete_space,
insert_space,
)
from nemo_text_processing.text_normalization.en.utils import get_abs_path
from pynini.examples import plurals
from pynini.lib import pynutil
class ElectronicFst(GraphFst):
    """
    Finite state transducer for verbalizing electronic
        e.g. tokens { electronic { username: "cdf1" domain: "abc.edu" } } -> c d f one at a b c dot e d u

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="electronic", kind="verbalize", deterministic=deterministic)

        drop_quote = pynutil.delete("\"")

        # Digits: the digit table is inverted; "0" additionally reads "o"/"oh" when non-deterministic.
        nonzero_digit = pynini.invert(pynini.string_file(get_abs_path("data/number/digit.tsv"))).optimize()
        zero = pynini.cross("0", "zero")
        if not deterministic:
            zero |= pynini.cross("0", "o") | pynini.cross("0", "oh")
        any_digit = nonzero_digit | zero

        symbols = pynini.string_file(get_abs_path("data/electronic/symbol.tsv")).optimize()

        # Rewrite every symbol/digit in a space-free string, padding each replacement with spaces.
        padded_replacement = pynutil.insert(" ") + (symbols | any_digit) + pynutil.insert(" ")
        spell_out = pynini.cdrewrite(padded_replacement, "", "", NEMO_SIGMA)
        spell_out = pynini.compose(pynini.closure(NEMO_NOT_SPACE), spell_out.optimize()).optimize()

        user_name = pynutil.delete("username:") + delete_space + drop_quote + spell_out + drop_quote

        # Domain: spelled-out prefix, then a known domain ending (or ".") and an optional remainder.
        common_domains = pynini.string_file(get_abs_path("data/electronic/domain.tsv"))
        domain_ending = plurals._priority_union(
            common_domains, pynutil.add_weight(pynini.cross(".", "dot"), weight=0.0001), NEMO_SIGMA
        )
        optional_remainder = pynini.closure(
            insert_space + (pynini.cdrewrite(TO_UPPER, "", "", NEMO_SIGMA) @ spell_out), 0, 1
        )
        domain_body = spell_out + insert_space + domain_ending + optional_remainder
        domain = (
            pynutil.delete("domain:") + delete_space + drop_quote + domain_body + delete_space + drop_quote
        ).optimize()

        protocol = pynutil.delete("protocol: \"") + pynini.closure(NEMO_NOT_QUOTE, 1) + pynutil.delete("\"")

        optional_protocol = pynini.closure(protocol + delete_space, 0, 1)
        optional_user = pynini.closure(user_name + delete_space + pynutil.insert(" at ") + delete_space, 0, 1)
        verbalized = (optional_protocol + optional_user + domain + delete_space).optimize()
        # Final pass applies delete_extra_space everywhere in the verbalized output.
        graph = verbalized @ pynini.cdrewrite(delete_extra_space, "", "", NEMO_SIGMA)

        self.fst = self.delete_tokens(graph).optimize()

View File

@@ -0,0 +1,88 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, NEMO_SIGMA, GraphFst, insert_space
from nemo_text_processing.text_normalization.en.verbalizers.ordinal import OrdinalFst
from pynini.examples import plurals
from pynini.lib import pynutil
class FractionFst(GraphFst):
    """
    Finite state transducer for verbalizing fraction
        e.g. tokens { fraction { integer_part: "twenty three" numerator: "four" denominator: "five" } } ->
        twenty three and four fifths

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
        lm: when True, "and" after the integer part stays mandatory even in non-deterministic mode
    """

    def __init__(self, deterministic: bool = True, lm: bool = False):
        super().__init__(name="fraction", kind="verbalize", deterministic=deterministic)

        ordinal_suffix = OrdinalFst().suffix
        any_quoted = pynini.closure(NEMO_NOT_QUOTE)

        integer = pynutil.delete("integer_part: \"") + any_quoted + pynutil.delete("\" ")

        # Denominators with dedicated readings take priority over the generic ordinal form.
        denominator_one = pynini.cross("denominator: \"one\"", "over one")
        denominator_half = pynini.cross("denominator: \"two\"", "half")
        denominator_quarter = pynini.cross("denominator: \"four\"", "quarter")
        denominator_rest = (
            pynutil.delete("denominator: \"") + pynini.closure(NEMO_NOT_QUOTE) @ ordinal_suffix + pynutil.delete("\"")
        )
        quarter_or_rest = plurals._priority_union(denominator_quarter, denominator_rest, NEMO_SIGMA)
        half_or_later = plurals._priority_union(denominator_half, quarter_or_rest, NEMO_SIGMA)
        denominators = plurals._priority_union(denominator_one, half_or_later, NEMO_SIGMA).optimize()
        if not deterministic:
            # also allow the ordinal reading of "four" next to "quarter"
            denominators |= (
                pynutil.delete("denominator: \"") + (pynini.accep("four") @ ordinal_suffix) + pynutil.delete("\"")
            )

        # Numerator "one" keeps the denominator as is.
        numerator_one = pynutil.delete("numerator: \"") + pynini.accep("one") + pynutil.delete("\" ")
        numerator_one = numerator_one + insert_space + denominators

        # Any other numerator: the end of the string gets "s" ("half" becomes "halves").
        numerator_rest = (
            pynutil.delete("numerator: \"") + (pynini.closure(NEMO_NOT_QUOTE) - pynini.accep("one")) + pynutil.delete("\" ")
        )
        numerator_rest = numerator_rest + insert_space + denominators
        pluralize = plurals._priority_union(pynini.cross("half", "halves"), pynutil.insert("s"), NEMO_SIGMA)
        numerator_rest = numerator_rest @ pynini.cdrewrite(pluralize, "", "[EOS]", NEMO_SIGMA)

        fraction_only = numerator_one | numerator_rest

        # "and" joins the integer part to the fraction; optional unless deterministic or lm.
        conjunction = pynutil.insert("and ")
        if not (deterministic or lm):
            conjunction = pynini.closure(conjunction, 0, 1)
        optional_integer = pynini.closure(integer + insert_space + conjunction, 0, 1)

        graph = optional_integer + fraction_only
        end_fixups = pynini.cross("and one half", "and a half") | pynini.cross("over ones", "over one")
        graph = graph @ pynini.cdrewrite(end_fixups, "", "[EOS]", NEMO_SIGMA)

        self.graph = graph
        self.fst = self.delete_tokens(self.graph).optimize()

View File

@@ -0,0 +1,102 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, GraphFst, delete_space, insert_space
from pynini.lib import pynutil
class MeasureFst(GraphFst):
    """
    Finite state transducer for verbalizing measure, e.g.
        measure { negative: "true" cardinal { integer: "twelve" } units: "kilograms" } -> minus twelve kilograms
        measure { decimal { integer_part: "twelve" fractional_part: "five" } units: "kilograms" } -> twelve point five kilograms
        tokens { measure { units: "covid" decimal { integer_part: "nineteen" fractional_part: "five" } } } -> covid nineteen point five

    Args:
        decimal: DecimalFst
        cardinal: CardinalFst
        fraction: FractionFst
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, decimal: GraphFst, cardinal: GraphFst, fraction: GraphFst, deterministic: bool = True):
        super().__init__(name="measure", kind="verbalize", deterministic=deterministic)
        sign = cardinal.optional_sign

        # Any quoted unit except the reserved "address" / "math" markers, which are handled separately below.
        unit_name = pynini.difference(pynini.closure(NEMO_NOT_QUOTE, 1), pynini.union("address", "math"))
        unit = pynutil.delete('units: "') + unit_name + pynutil.delete('"') + delete_space
        if not deterministic:
            inch_mark = pynini.cross(pynini.union("inch", "inches"), '"')
            unit |= pynini.compose(unit, inch_mark)

        def signed_number(opening, numbers):
            # "<opening> [sign] <numbers> }" with the wrapper removed
            return (
                pynutil.delete(opening)
                + delete_space
                + sign
                + delete_space
                + numbers
                + delete_space
                + pynutil.delete("}")
            )

        graph_cardinal = signed_number("cardinal {", cardinal.numbers)
        graph_decimal = signed_number("decimal {", decimal.numbers)
        graph_fraction = (
            pynutil.delete("fraction {") + delete_space + fraction.graph + delete_space + pynutil.delete("}")
        )

        # number followed by unit
        any_number = graph_cardinal | graph_decimal | graph_fraction
        graph = any_number + delete_space + insert_space + unit

        # SH adds "preserve_order: true" by default
        preserve_order = pynutil.delete("preserve_order:") + delete_space + pynutil.delete("true") + delete_space

        # unit followed by number
        unit_first = (
            unit + insert_space + (graph_cardinal | graph_decimal) + delete_space + pynini.closure(preserve_order)
        )
        graph |= unit_first

        # for only unit
        unit_only = (
            pynutil.delete('cardinal { integer: "-"')
            + delete_space
            + pynutil.delete("}")
            + delete_space
            + unit
            + pynini.closure(preserve_order)
        )
        graph |= unit_only

        def tagged_cardinal(tag):
            # a reserved units marker followed by a plain cardinal
            return (
                pynutil.delete(tag)
                + delete_space
                + graph_cardinal
                + delete_space
                + pynini.closure(preserve_order)
            )

        address_graph = tagged_cardinal('units: "address" ')
        math_graph = tagged_cardinal('units: "math" ')
        graph |= address_graph | math_graph

        self.fst = self.delete_tokens(graph).optimize()

View File

@@ -0,0 +1,71 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
NEMO_NOT_QUOTE,
GraphFst,
delete_extra_space,
delete_preserve_order,
)
from pynini.lib import pynutil
class MoneyFst(GraphFst):
    """
    Finite state transducer for verbalizing money. The graph accepts an integer part followed by a
    ``currency_maj`` field, optionally followed by a ``fractional_part`` / ``currency_min`` pair,
    a decimal number followed by ``currency_maj``, or a ``fractional_part`` / ``currency_min`` pair alone.

    Args:
        decimal: DecimalFst
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, decimal: GraphFst, deterministic: bool = True):
        super().__init__(name="money", kind="verbalize", deterministic=deterministic)

        def quoted(field):
            # <field>: "<value>" -> <value>
            return pynutil.delete(f'{field}: "') + pynini.closure(NEMO_NOT_QUOTE, 1) + pynutil.delete('"')

        space = pynini.accep(" ")
        major = quoted("currency_maj")
        minor = quoted("currency_min")
        cents = quoted("fractional_part")
        whole = decimal.integer

        # *** currency_maj
        major_only = whole + space + major

        # *** currency_maj + (***) | ((and) *** current_min)
        minor_amount = cents + delete_extra_space + minor
        if not deterministic:
            minor_amount |= pynutil.insert("and ") + minor_amount
        major_and_minor = whole + space + major + space + minor_amount + delete_preserve_order

        # *** point *** currency_maj
        decimal_major = decimal.numbers + space + major

        # *** current_min
        minor_only = cents + delete_extra_space + minor + delete_preserve_order

        graph = major_only | major_and_minor | decimal_major | minor_only
        if not deterministic:
            graph |= major_only + delete_preserve_order

        self.fst = self.delete_tokens(graph).optimize()

View File

@@ -0,0 +1,53 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, NEMO_SIGMA, GraphFst, delete_space
from nemo_text_processing.text_normalization.en.utils import get_abs_path
from pynini.lib import pynutil
class OrdinalFst(GraphFst):
    """
    Finite state transducer for verbalizing ordinal, e.g.
        ordinal { integer: "thirteen" } -> thirteenth

    Exposes ``self.suffix`` (the end-of-string rewrite that turns a number into its ordinal form)
    and ``self.graph`` (the quoted integer field composed with that rewrite).

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="ordinal", kind="verbalize", deterministic=deterministic)

        # Inverted lookup tables loaded from the ordinal data files.
        digit_map, teen_map = (
            pynini.string_file(get_abs_path(f"data/ordinal/{stem}.tsv")).invert() for stem in ("digit", "teen")
        )

        # Table entries, "ty" -> "tieth", or a plain inserted "th" at the end of the string.
        ending = digit_map | teen_map | pynini.cross("ty", "tieth") | pynutil.insert("th")
        suffix = pynini.cdrewrite(ending, "", "[EOS]", NEMO_SIGMA).optimize()

        # integer: "<words>" -> <words>
        integer = (
            pynutil.delete("integer:")
            + delete_space
            + pynutil.delete('"')
            + pynini.closure(NEMO_NOT_QUOTE, 1)
            + pynutil.delete('"')
        )

        self.suffix = suffix
        self.graph = integer @ suffix
        self.fst = self.delete_tokens(self.graph).optimize()

View File

@@ -0,0 +1,180 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
MIN_NEG_WEIGHT,
NEMO_ALPHA,
NEMO_CHAR,
NEMO_SIGMA,
NEMO_SPACE,
generator_main,
)
from nemo_text_processing.text_normalization.en.taggers.punctuation import PunctuationFst
from pynini.lib import pynutil
class PostProcessingFst:
    """
    Finite state transducer that post-processes an entire sentence after verbalization is complete, e.g.
    removes extra spaces around punctuation marks " ( one hundred and twenty three ) " -> "(one hundred and twenty three)"

    Args:
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
    """

    def __init__(self, cache_dir: str = None, overwrite_cache: bool = False):
        # Load the graph from "<cache_dir>/en_tn_post_processing.far" when it exists and overwriting
        # is not requested; otherwise build it and, if a cache dir was given, export it there.
        far_file = None
        # the literal string "None" is treated the same as None
        if cache_dir is not None and cache_dir != "None":
            os.makedirs(cache_dir, exist_ok=True)
            far_file = os.path.join(cache_dir, "en_tn_post_processing.far")
        if not overwrite_cache and far_file and os.path.exists(far_file):
            self.fst = pynini.Far(far_file, mode="r")["post_process_graph"]
        else:
            self.set_punct_dict()
            self.fst = self.get_punct_postprocess_graph()

            if far_file:
                generator_main(far_file, {"post_process_graph": self.fst})

    def set_punct_dict(self):
        """
        Sets ``self.punct_marks``: a dict mapping "'" to the list of characters that are treated as
        single-quote variants by ``get_punct_postprocess_graph``.
        """
        # NOTE(review): several entries below are empty strings in this copy of the file. They look like
        # characters lost to an encoding problem - confirm against the upstream source. An empty string
        # inside the pynini.union built from this list would match nothing-at-all as a "quote".
        self.punct_marks = {
            "'": [
                "'",
                '´',
                'ʹ',
                'ʻ',
                'ʼ',
                'ʽ',
                'ʾ',
                'ˈ',
                'ˊ',
                'ˋ',
                '˴',
                'ʹ',
                '΄',
                '՚',
                '՝',
                'י',
                '׳',
                'ߴ',
                'ߵ',
                '',
                '',
                '',
                '᾿',
                '',
                '',
                '',
                '',
                '',
                '',
                '',
                '',
                '',
                '',
                '',
                '𖽑',
                '𖽒',
            ],
        }

    def get_punct_postprocess_graph(self):
        """
        Returns graph to post process punctuation marks.

        {``} quotes are converted to {"}. Note, if there are spaces around single quote {'}, they will be kept.
        By default, a space is added after a punctuation mark, and spaces are removed before punctuation marks.
        """
        punct_marks_all = PunctuationFst().punct_marks

        # Marks listed in allow_space_before_punct may keep a preceding space;
        # every other mark from punct_marks_all is assumed to have no space before it.
        quotes = ["'", "\"", "``", "«"]
        # NOTE(review): the second dash and the last double-quote pair below are empty strings in this
        # copy - presumably characters lost in transit; confirm against the upstream source.
        dashes = ["-", ""]
        brackets = ["<", "{", "("]
        open_close_single_quotes = [
            ("`", "`"),
        ]

        open_close_double_quotes = [('"', '"'), ("``", "``"), ("", "")]
        open_close_symbols = open_close_single_quotes + open_close_double_quotes
        # only the opening symbol (k[0]) of each pair is added to the allow list
        allow_space_before_punct = ["&"] + quotes + dashes + brackets + [k[0] for k in open_close_symbols]

        no_space_before_punct = [m for m in punct_marks_all if m not in allow_space_before_punct]
        no_space_before_punct = pynini.union(*no_space_before_punct)
        no_space_after_punct = pynini.union(*brackets)
        # local space-deleting acceptors used throughout this method
        delete_space = pynutil.delete(" ")
        delete_space_optional = pynini.closure(delete_space, 0, 1)

        # non_punct allows space
        # delete space before no_space_before_punct marks, if present
        non_punct = pynini.difference(NEMO_CHAR, no_space_before_punct).optimize()
        graph = (
            pynini.closure(non_punct)
            + pynini.closure(
                # the space-deleting path carries MIN_NEG_WEIGHT, the plain path carries no extra weight
                no_space_before_punct | pynutil.add_weight(delete_space + no_space_before_punct, MIN_NEG_WEIGHT)
            )
            + pynini.closure(non_punct)
        )
        graph = pynini.closure(graph).optimize()
        # convert `` to " everywhere
        graph = pynini.compose(
            graph, pynini.cdrewrite(pynini.cross("``", '"'), "", "", NEMO_SIGMA).optimize()
        ).optimize()

        # remove space after no_space_after_punct (even if there are no matching closing brackets)
        no_space_after_punct = pynini.cdrewrite(delete_space, no_space_after_punct, NEMO_SIGMA, NEMO_SIGMA).optimize()
        graph = pynini.compose(graph, no_space_after_punct).optimize()

        # remove space around text in quotes
        # (the "single quote" matched here is the backtick character)
        single_quote = pynutil.add_weight(pynini.accep("`"), MIN_NEG_WEIGHT)
        double_quotes = pynutil.add_weight(pynini.accep('"'), MIN_NEG_WEIGHT)
        quotes_graph = (
            single_quote + delete_space_optional + NEMO_ALPHA + NEMO_SIGMA + delete_space_optional + single_quote
        ).optimize()

        # this is to make sure multiple quotes are tagged from right to left without skipping any quotes in the left
        not_alpha = pynini.difference(NEMO_CHAR, NEMO_ALPHA).optimize() | pynutil.add_weight(
            NEMO_SPACE, MIN_NEG_WEIGHT
        )
        end = pynini.closure(pynutil.add_weight(not_alpha, MIN_NEG_WEIGHT))
        quotes_graph |= (
            double_quotes
            + delete_space_optional
            + NEMO_ALPHA
            + NEMO_SIGMA
            + delete_space_optional
            + double_quotes
            + end
        )

        quotes_graph = pynutil.add_weight(quotes_graph, MIN_NEG_WEIGHT)
        # zero or more quoted spans, each surrounded by arbitrary text
        quotes_graph = NEMO_SIGMA + pynini.closure(NEMO_SIGMA + quotes_graph + NEMO_SIGMA)

        graph = pynini.compose(graph, quotes_graph).optimize()

        # remove space between a word and a single quote followed by s
        # (left context: a letter; right context: "s " or "s" at the end of the string)
        remove_space_around_single_quote = pynini.cdrewrite(
            delete_space_optional + pynini.union(*self.punct_marks["'"]) + delete_space,
            NEMO_ALPHA,
            pynini.union("s ", "s[EOS]"),
            NEMO_SIGMA,
        )

        graph = pynini.compose(graph, remove_space_around_single_quote).optimize()
        return graph

View File

@@ -0,0 +1,68 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, GraphFst
from nemo_text_processing.text_normalization.en.verbalizers.ordinal import OrdinalFst
from pynini.lib import pynutil
class RomanFst(GraphFst):
    """
    Finite state transducer for verbalizing roman numerals
        e.g. tokens { roman { integer: "one" } } -> one

    Four input shapes are accepted: a ``key_cardinal`` or ``key_the_ordinal`` field followed by the
    integer, or a ``default_cardinal`` / ``default_ordinal`` marker followed by the integer.

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="roman", kind="verbalize", deterministic=deterministic)
        ordinal_suffix = OrdinalFst().suffix
        cardinal = pynini.closure(NEMO_NOT_QUOTE)
        ordinal = pynini.compose(cardinal, ordinal_suffix)

        # key_cardinal: "<key>" integer: "<number>" -> <key> <number>
        keyed_cardinal = (
            pynutil.delete('key_cardinal: "')
            + pynini.closure(NEMO_NOT_QUOTE, 1)
            + pynutil.delete('"')
            + pynini.accep(" ")
            + pynutil.delete('integer: "')
            + cardinal
            + pynutil.delete('"')
        ).optimize()

        default_cardinal = (
            pynutil.delete('default_cardinal: "default" integer: "') + cardinal + pynutil.delete('"')
        ).optimize()

        default_ordinal = (
            pynutil.delete('default_ordinal: "default" integer: "') + ordinal + pynutil.delete('"')
        ).optimize()

        # key_the_ordinal: "<key>" integer: "<number>" -> <key> [the ]<ordinal>
        keyed_ordinal = (
            pynutil.delete('key_the_ordinal: "')
            + pynini.closure(NEMO_NOT_QUOTE, 1)
            + pynutil.delete('"')
            + pynini.accep(" ")
            + pynutil.delete('integer: "')
            + pynini.closure(pynutil.insert("the "), 0, 1)
            + ordinal
            + pynutil.delete('"')
        ).optimize()

        graph = keyed_cardinal
        for branch in (default_cardinal, default_ordinal, keyed_ordinal):
            graph |= branch

        self.fst = self.delete_tokens(graph).optimize()

View File

@@ -0,0 +1,63 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, GraphFst, delete_space, insert_space
from pynini.lib import pynutil
class TelephoneFst(GraphFst):
    """
    Finite state transducer for verbalizing telephone numbers, e.g.
        telephone { country_code: "one" number_part: "one two three, one two three, five six seven eight" extension: "one" }
        -> one, one two three, one two three, five six seven eight, one

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="telephone", kind="verbalize", deterministic=deterministic)

        # country_code: "<words>" -> "<words> " (the whole field is optional)
        country_code = (
            pynutil.delete('country_code: "')
            + pynini.closure(NEMO_NOT_QUOTE, 1)
            + pynutil.delete('"')
            + delete_space
            + insert_space
        )
        optional_country_code = pynini.closure(country_code, 0, 1)

        # number_part: "<words>" -> <words>; a single space before the closing quote may be
        # dropped, and that path is given a small negative weight
        drop_trailing_space = pynini.closure(pynutil.add_weight(pynutil.delete(" "), -0.0001), 0, 1)
        number_part = (
            pynutil.delete('number_part: "')
            + pynini.closure(NEMO_NOT_QUOTE, 1)
            + drop_trailing_space
            + pynutil.delete('"')
        )

        # extension: "<words>" -> " <words>" (the whole field is optional)
        extension = (
            delete_space
            + insert_space
            + pynutil.delete('extension: "')
            + pynini.closure(NEMO_NOT_QUOTE, 1)
            + pynutil.delete('"')
        )
        optional_extension = pynini.closure(extension, 0, 1)

        graph = optional_country_code + number_part + optional_extension
        self.fst = self.delete_tokens(graph).optimize()

View File

@@ -0,0 +1,102 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
NEMO_NOT_QUOTE,
NEMO_SIGMA,
GraphFst,
delete_space,
insert_space,
)
from pynini.lib import pynutil
class TimeFst(GraphFst):
    """
    Finite state transducer for verbalizing time, e.g.
        time { hours: "twelve" minutes: "thirty" suffix: "a m" zone: "e s t" } -> twelve thirty a m e s t
        time { hours: "twelve" } -> twelve o'clock

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="time", kind="verbalize", deterministic=deterministic)

        def quoted_field(key):
            # <key>: "<value>" -> <value>
            return (
                pynutil.delete(f"{key}:")
                + delete_space
                + pynutil.delete('"')
                + pynini.closure(NEMO_NOT_QUOTE, 1)
                + pynutil.delete('"')
            )

        hour = quoted_field("hours")
        minute = quoted_field("minutes")
        suffix = quoted_field("suffix")
        zone = quoted_field("zone")
        second = quoted_field("seconds")

        optional_suffix = pynini.closure(delete_space + insert_space + suffix, 0, 1)
        optional_zone = pynini.closure(delete_space + insert_space + zone, 0, 1)

        # "<h> hours <m> minutes and <s> seconds" with optional suffix and zone
        graph_hms = (
            hour
            + pynutil.insert(" hours ")
            + delete_space
            + minute
            + pynutil.insert(" minutes and ")
            + delete_space
            + second
            + pynutil.insert(" seconds")
            + optional_suffix
            + optional_zone
        )
        # After a space or at the start of the string: drop "o " and use singular units after "one".
        hms_fix_ups = (
            pynutil.delete("o ")
            | pynini.cross("one minutes", "one minute")
            | pynini.cross("one seconds", "one second")
            | pynini.cross("one hours", "one hour")
        )
        graph_hms @= pynini.cdrewrite(hms_fix_ups, pynini.union(" ", "[BOS]"), "", NEMO_SIGMA)

        # hours + minutes
        graph = hour + delete_space + insert_space + minute + optional_suffix + optional_zone
        alternatives = (
            # hours only -> "<h> o'clock"
            hour + insert_space + pynutil.insert("o'clock") + optional_zone,
            # hours + suffix
            hour + delete_space + insert_space + suffix + optional_zone,
            graph_hms,
        )
        for alternative in alternatives:
            graph |= alternative

        self.fst = self.delete_tokens(graph).optimize()

View File

@@ -0,0 +1,82 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nemo_text_processing.text_normalization.en.graph_utils import GraphFst
from nemo_text_processing.text_normalization.en.verbalizers.abbreviation import AbbreviationFst
from nemo_text_processing.text_normalization.en.verbalizers.cardinal import CardinalFst
from nemo_text_processing.text_normalization.en.verbalizers.date import DateFst
from nemo_text_processing.text_normalization.en.verbalizers.decimal import DecimalFst
from nemo_text_processing.text_normalization.en.verbalizers.electronic import ElectronicFst
from nemo_text_processing.text_normalization.en.verbalizers.fraction import FractionFst
from nemo_text_processing.text_normalization.en.verbalizers.measure import MeasureFst
from nemo_text_processing.text_normalization.en.verbalizers.money import MoneyFst
from nemo_text_processing.text_normalization.en.verbalizers.ordinal import OrdinalFst
from nemo_text_processing.text_normalization.en.verbalizers.roman import RomanFst
from nemo_text_processing.text_normalization.en.verbalizers.telephone import TelephoneFst
from nemo_text_processing.text_normalization.en.verbalizers.time import TimeFst
from nemo_text_processing.text_normalization.en.verbalizers.whitelist import WhiteListFst
class VerbalizeFst(GraphFst):
    """
    Composes other verbalizer grammars.
    For deployment, this grammar will be compiled and exported to OpenFst Finite State Archive (FAR) File.
    More details to deployment at NeMo/tools/text_processing_deployment.

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple options (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="verbalize", kind="verbalize", deterministic=deterministic)

        # Grammars that other grammars are built from are kept as objects.
        cardinal = CardinalFst(deterministic=deterministic)
        decimal = DecimalFst(cardinal=cardinal, deterministic=deterministic)
        ordinal = OrdinalFst(deterministic=deterministic)
        fraction = FractionFst(deterministic=deterministic)

        telephone_graph = TelephoneFst(deterministic=deterministic).fst
        electronic_graph = ElectronicFst(deterministic=deterministic).fst
        measure_graph = MeasureFst(
            decimal=decimal, cardinal=cardinal, fraction=fraction, deterministic=deterministic
        ).fst
        time_graph = TimeFst(deterministic=deterministic).fst
        date_graph = DateFst(ordinal=ordinal, deterministic=deterministic).fst
        money_graph = MoneyFst(decimal=decimal, deterministic=deterministic).fst
        whitelist_graph = WhiteListFst(deterministic=deterministic).fst

        # Union of every semiotic-class verbalizer, folded left to right.
        graph = time_graph
        for member in (
            date_graph,
            money_graph,
            measure_graph,
            ordinal.fst,
            decimal.fst,
            cardinal.fst,
            telephone_graph,
            electronic_graph,
            fraction.fst,
            whitelist_graph,
        ):
            graph = graph | member

        graph |= RomanFst(deterministic=deterministic).fst

        if not deterministic:
            # abbreviations are only verbalized in the non-deterministic setup
            graph |= AbbreviationFst(deterministic=deterministic).fst

        self.fst = graph

View File

@@ -0,0 +1,75 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
GraphFst,
delete_extra_space,
delete_space,
generator_main,
)
from nemo_text_processing.text_normalization.en.verbalizers.verbalize import VerbalizeFst
from nemo_text_processing.text_normalization.en.verbalizers.word import WordFst
from pynini.lib import pynutil
class VerbalizeFinalFst(GraphFst):
    """
    Finite state transducer that verbalizes an entire sentence, e.g.
    tokens { name: "its" } tokens { time { hours: "twelve" minutes: "thirty" } } tokens { name: "now" } tokens { name: "." } -> its twelve thirty now .

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple options (used for audio-based normalization)
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
    """

    def __init__(self, deterministic: bool = True, cache_dir: str = None, overwrite_cache: bool = False):
        super().__init__(name="verbalize_final", kind="verbalize", deterministic=deterministic)

        # Resolve the cache location; the literal string "None" disables caching just like None.
        far_file = None
        if not (cache_dir is None or cache_dir == "None"):
            os.makedirs(cache_dir, exist_ok=True)
            far_file = os.path.join(cache_dir, f"en_tn_{deterministic}_deterministic_verbalizer.far")

        # Cache hit: reuse the stored grammar.
        if far_file and not overwrite_cache and os.path.exists(far_file):
            self.fst = pynini.Far(far_file, mode="r")["verbalize"]
            return

        token_types = VerbalizeFst(deterministic=deterministic).fst | WordFst(deterministic=deterministic).fst

        if deterministic:
            # strip the "tokens { ... }" wrapper around every token
            token = (
                pynutil.delete("tokens")
                + delete_space
                + pynutil.delete("{")
                + delete_space
                + token_types
                + delete_space
                + pynutil.delete("}")
            )
        else:
            token = delete_space + token_types + delete_space

        # one or more tokens separated by delete_extra_space
        sentence = delete_space + pynini.closure(token + delete_extra_space) + token + delete_space
        self.fst = sentence.optimize()

        if far_file:
            generator_main(far_file, {"verbalize": self.fst})

View File

@@ -0,0 +1,39 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_CHAR, NEMO_SIGMA, GraphFst, delete_space
from pynini.lib import pynutil
class WhiteListFst(GraphFst):
    """
    Finite state transducer for verbalizing whitelist
        e.g. tokens { name: "misses" } -> misses

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="whitelist", kind="verbalize", deterministic=deterministic)

        # name: "<text without plain spaces>" -> <text>
        value = pynini.closure(NEMO_CHAR - " ", 1)
        name_field = (
            pynutil.delete("name:") + delete_space + pynutil.delete('"') + value + pynutil.delete('"')
        )

        # non-breaking spaces inside the value become regular spaces
        nbsp_to_space = pynini.cdrewrite(pynini.cross("\u00A0", " "), "", "", NEMO_SIGMA)
        self.fst = pynini.compose(name_field, nbsp_to_space).optimize()

View File

@@ -0,0 +1,35 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_CHAR, NEMO_SIGMA, GraphFst, delete_space
from pynini.lib import pynutil
class WordFst(GraphFst):
    """
    Finite state transducer for verbalizing word
        e.g. tokens { name: "sleep" } -> sleep

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="word", kind="verbalize", deterministic=deterministic)

        # one or more characters, none of them a plain space
        word_chars = pynini.closure(NEMO_CHAR - " ", 1)
        name_field = (
            pynutil.delete("name:") + delete_space + pynutil.delete('"') + word_chars + pynutil.delete('"')
        )

        # non-breaking spaces inside the word become regular spaces
        nbsp_to_space = pynini.cdrewrite(pynini.cross("\u00A0", " "), "", "", NEMO_SIGMA)
        self.fst = pynini.compose(name_field, nbsp_to_space).optimize()

View File

@@ -0,0 +1,479 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import os
import re
from argparse import ArgumentParser
from collections import OrderedDict
from math import factorial
from time import perf_counter
from typing import Dict, List, Union
import pynini
import regex
from nemo_text_processing.text_normalization.data_loader_utils import (
load_file,
post_process_punct,
pre_process,
write_file,
)
from nemo_text_processing.text_normalization.token_parser import PRESERVE_ORDER_KEY, TokenParser
from pynini.lib.rewrite import top_rewrite
SPACE_DUP = re.compile(' {2,}')
class Normalizer:
"""
Normalizer class that converts text from written to spoken form.
Useful for TTS preprocessing.
Args:
input_case: expected input capitalization
lang: language specifying the TN rules, by default: English
cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
overwrite_cache: set to True to overwrite .far files
whitelist: path to a file with whitelist replacements
post_process: WFST-based post processing, e.g. to remove extra spaces added during TN.
Note: punct_post_process flag in normalize() supports all languages.
"""
def __init__(
    self,
    input_case: str,
    lang: str = 'en',
    deterministic: bool = True,
    cache_dir: str = None,
    overwrite_cache: bool = False,
    whitelist: str = None,
    lm: bool = False,
    post_process: bool = True,
):
    """
    Load the tagger and verbalizer grammars for ``lang``.

    Args:
        input_case: expected input capitalization, "lower_cased" or "cased"
        lang: language specifying the TN rules; one of "en", "ru", "de", "es"
        deterministic: passed to the tagger and the verbalizer; for English it also selects the tagger module
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
        whitelist: path to a file with whitelist replacements
        lm: English, non-deterministic only: use the ``tokenize_and_classify_lm`` tagger
            instead of ``tokenize_and_classify_with_audio``
        post_process: English only: build the WFST-based post processor

    Raises:
        AssertionError: if ``input_case`` is not one of the two supported values
        ValueError: if ``lang`` is not supported
    """
    assert input_case in ["lower_cased", "cased"]

    self.post_processor = None

    # Grammar modules are imported lazily so that only the requested language is loaded.
    if lang == "en":
        from nemo_text_processing.text_normalization.en.verbalizers.post_processing import PostProcessingFst
        from nemo_text_processing.text_normalization.en.verbalizers.verbalize_final import VerbalizeFinalFst

        if post_process:
            self.post_processor = PostProcessingFst(cache_dir=cache_dir, overwrite_cache=overwrite_cache)

        if deterministic:
            from nemo_text_processing.text_normalization.en.taggers.tokenize_and_classify import ClassifyFst
        else:
            if lm:
                from nemo_text_processing.text_normalization.en.taggers.tokenize_and_classify_lm import ClassifyFst
            else:
                from nemo_text_processing.text_normalization.en.taggers.tokenize_and_classify_with_audio import (
                    ClassifyFst,
                )
    elif lang == 'ru':
        # Ru TN only support non-deterministic cases and produces multiple normalization options
        # use normalize_with_audio.py
        from nemo_text_processing.text_normalization.ru.taggers.tokenize_and_classify import ClassifyFst
        from nemo_text_processing.text_normalization.ru.verbalizers.verbalize_final import VerbalizeFinalFst
    elif lang == 'de':
        from nemo_text_processing.text_normalization.de.taggers.tokenize_and_classify import ClassifyFst
        from nemo_text_processing.text_normalization.de.verbalizers.verbalize_final import VerbalizeFinalFst
    elif lang == 'es':
        from nemo_text_processing.text_normalization.es.taggers.tokenize_and_classify import ClassifyFst
        from nemo_text_processing.text_normalization.es.verbalizers.verbalize_final import VerbalizeFinalFst
    else:
        # Without this branch ClassifyFst stays unbound and the constructor dies with an
        # UnboundLocalError that does not mention the real cause.
        raise ValueError(f"Unsupported language '{lang}'; expected one of: en, ru, de, es")

    self.tagger = ClassifyFst(
        input_case=input_case,
        deterministic=deterministic,
        cache_dir=cache_dir,
        overwrite_cache=overwrite_cache,
        whitelist=whitelist,
    )
    self.verbalizer = VerbalizeFinalFst(
        deterministic=deterministic, cache_dir=cache_dir, overwrite_cache=overwrite_cache
    )

    self.parser = TokenParser()
    self.lang = lang

    self.processor = 0
def __process_batch(self, batch, verbose, punct_pre_process, punct_post_process):
"""
Normalizes batch of text sequences
Args:
batch: list of texts
verbose: whether to print intermediate meta information
punct_pre_process: whether to do punctuation pre processing
punct_post_process: whether to do punctuation post processing
"""
normalized_lines = [
self.normalize(
text, verbose=verbose, punct_pre_process=punct_pre_process, punct_post_process=punct_post_process
)
for text in tqdm(batch)
]
return normalized_lines
def _estimate_number_of_permutations_in_nested_dict(
self, token_group: Dict[str, Union[OrderedDict, str, bool]]
) -> int:
num_perms = 1
for k, inner in token_group.items():
if isinstance(inner, dict):
num_perms *= self._estimate_number_of_permutations_in_nested_dict(inner)
num_perms *= factorial(len(token_group))
return num_perms
def _split_tokens_to_reduce_number_of_permutations(
self, tokens: List[dict], max_number_of_permutations_per_split: int = 729
) -> List[List[dict]]:
"""
Splits a sequence of tokens in a smaller sequences of tokens in a way that maximum number of composite
tokens permutations does not exceed ``max_number_of_permutations_per_split``.
For example,
.. code-block:: python
tokens = [
{"tokens": {"date": {"year": "twenty eighteen", "month": "december", "day": "thirty one"}}},
{"tokens": {"date": {"year": "twenty eighteen", "month": "january", "day": "eight"}}},
]
split = normalizer._split_tokens_to_reduce_number_of_permutations(
tokens, max_number_of_permutations_per_split=6
)
assert split == [
[{"tokens": {"date": {"year": "twenty eighteen", "month": "december", "day": "thirty one"}}}],
[{"tokens": {"date": {"year": "twenty eighteen", "month": "january", "day": "eight"}}}],
]
Date tokens contain 3 items each which gives 6 permutations for every date. Since there are 2 dates, total
number of permutations would be ``6 * 6 == 36``. Parameter ``max_number_of_permutations_per_split`` equals 6,
so input sequence of tokens is split into 2 smaller sequences.
Args:
tokens (:obj:`List[dict]`): a list of dictionaries, possibly nested.
max_number_of_permutations_per_split (:obj:`int`, `optional`, defaults to :obj:`243`): a maximum number
of permutations which can be generated from input sequence of tokens.
Returns:
:obj:`List[List[dict]]`: a list of smaller sequences of tokens resulting from ``tokens`` split.
"""
splits = []
prev_end_of_split = 0
current_number_of_permutations = 1
for i, token_group in enumerate(tokens):
n = self._estimate_number_of_permutations_in_nested_dict(token_group)
if n * current_number_of_permutations > max_number_of_permutations_per_split:
splits.append(tokens[prev_end_of_split:i])
prev_end_of_split = i
current_number_of_permutations = 1
if n > max_number_of_permutations_per_split:
raise ValueError(
f"Could not split token list with respect to condition that every split can generate number of "
f"permutations less or equal to "
f"`max_number_of_permutations_per_split={max_number_of_permutations_per_split}`. "
f"There is an unsplittable token group that generates more than "
f"{max_number_of_permutations_per_split} permutations. Try to increase "
f"`max_number_of_permutations_per_split` parameter."
)
current_number_of_permutations *= n
splits.append(tokens[prev_end_of_split:])
assert sum([len(s) for s in splits]) == len(tokens)
return splits
def normalize(
self, text: str, verbose: bool = False, punct_pre_process: bool = False, punct_post_process: bool = False
) -> str:
"""
Main function. Normalizes tokens from written to spoken form
e.g. 12 kg -> twelve kilograms
Args:
text: string that may include semiotic classes
verbose: whether to print intermediate meta information
punct_pre_process: whether to perform punctuation pre-processing, for example, [25] -> [ 25 ]
punct_post_process: whether to normalize punctuation
Returns: spoken form
"""
original_text = text
if punct_pre_process:
text = pre_process(text)
text = text.strip()
if not text:
if verbose:
print(text)
return text
text = pynini.escape(text)
tagged_lattice = self.find_tags(text)
tagged_text = self.select_tag(tagged_lattice)
if verbose:
print(tagged_text)
self.parser(tagged_text)
tokens = self.parser.parse()
split_tokens = self._split_tokens_to_reduce_number_of_permutations(tokens)
output = ""
for s in split_tokens:
tags_reordered = self.generate_permutations(s)
verbalizer_lattice = None
for tagged_text in tags_reordered:
tagged_text = pynini.escape(tagged_text)
verbalizer_lattice = self.find_verbalizer(tagged_text)
if verbalizer_lattice.num_states() != 0:
break
if verbalizer_lattice is None:
raise ValueError(f"No permutations were generated from tokens {s}")
output += ' ' + self.select_verbalizer(verbalizer_lattice)
output = SPACE_DUP.sub(' ', output[1:])
if self.lang == "en" and hasattr(self, 'post_processor'):
output = self.post_process(output)
if punct_post_process:
# do post-processing based on Moses detokenizer
if self.processor:
output = self.processor.moses_detokenizer.detokenize([output], unescape=False)
output = post_process_punct(input=original_text, normalized_text=output)
else:
print("NEMO_NLP collection is not available: skipping punctuation post_processing")
return output
def split_text_into_sentences(self, text: str) -> List[str]:
"""
Split text into sentences.
Args:
text: text
Returns list of sentences
"""
lower_case_unicode = ''
upper_case_unicode = ''
if self.lang == "ru":
lower_case_unicode = '\u0430-\u04FF'
upper_case_unicode = '\u0410-\u042F'
# Read and split transcript by utterance (roughly, sentences)
split_pattern = f"(?<!\w\.\w.)(?<![A-Z{upper_case_unicode}][a-z{lower_case_unicode}]+\.)(?<![A-Z{upper_case_unicode}]\.)(?<=\.|\?|\!|\.”|\?”\!”)\s(?![0-9]+[a-z]*\.)"
sentences = regex.split(split_pattern, text)
return sentences
def _permute(self, d: OrderedDict) -> List[str]:
"""
Creates reorderings of dictionary elements and serializes as strings
Args:
d: (nested) dictionary of key value pairs
Return permutations of different string serializations of key value pairs
"""
l = []
if PRESERVE_ORDER_KEY in d.keys():
d_permutations = [d.items()]
else:
d_permutations = itertools.permutations(d.items())
for perm in d_permutations:
subl = [""]
for k, v in perm:
if isinstance(v, str):
subl = ["".join(x) for x in itertools.product(subl, [f"{k}: \"{v}\" "])]
elif isinstance(v, OrderedDict):
rec = self._permute(v)
subl = ["".join(x) for x in itertools.product(subl, [f" {k} {{ "], rec, [f" }} "])]
elif isinstance(v, bool):
subl = ["".join(x) for x in itertools.product(subl, [f"{k}: true "])]
else:
raise ValueError()
l.extend(subl)
return l
def generate_permutations(self, tokens: List[dict]):
"""
Generates permutations of string serializations of list of dictionaries
Args:
tokens: list of dictionaries
Returns string serialization of list of dictionaries
"""
def _helper(prefix: str, tokens: List[dict], idx: int):
"""
Generates permutations of string serializations of given dictionary
Args:
tokens: list of dictionaries
prefix: prefix string
idx: index of next dictionary
Returns string serialization of dictionary
"""
if idx == len(tokens):
yield prefix
return
token_options = self._permute(tokens[idx])
for token_option in token_options:
yield from _helper(prefix + token_option, tokens, idx + 1)
return _helper("", tokens, 0)
def find_tags(self, text: str) -> 'pynini.FstLike':
"""
Given text use tagger Fst to tag text
Args:
text: sentence
Returns: tagged lattice
"""
lattice = text @ self.tagger.fst
return lattice
def select_tag(self, lattice: 'pynini.FstLike') -> str:
"""
Given tagged lattice return shortest path
Args:
tagged_text: tagged text
Returns: shortest path
"""
tagged_text = pynini.shortestpath(lattice, nshortest=1, unique=True).string()
return tagged_text
def find_verbalizer(self, tagged_text: str) -> 'pynini.FstLike':
"""
Given tagged text creates verbalization lattice
This is context-independent.
Args:
tagged_text: input text
Returns: verbalized lattice
"""
lattice = tagged_text @ self.verbalizer.fst
return lattice
def select_verbalizer(self, lattice: 'pynini.FstLike') -> str:
"""
Given verbalized lattice return shortest path
Args:
lattice: verbalization lattice
Returns: shortest path
"""
output = pynini.shortestpath(lattice, nshortest=1, unique=True).string()
# lattice = output @ self.verbalizer.punct_graph
# output = pynini.shortestpath(lattice, nshortest=1, unique=True).string()
return output
def post_process(self, normalized_text: 'pynini.FstLike') -> str:
"""
Runs post processing graph on normalized text
Args:
normalized_text: normalized text
Returns: shortest path
"""
normalized_text = normalized_text.strip()
if not normalized_text:
return normalized_text
normalized_text = pynini.escape(normalized_text)
if self.post_processor is not None:
normalized_text = top_rewrite(normalized_text, self.post_processor.fst)
return normalized_text
def parse_args():
    """
    Declares the command line options of the normalization script and parses the command line.

    ``--text`` and ``--input_file`` are registered in one mutually exclusive group.

    Returns: namespace with the parsed options
    """
    arg_parser = ArgumentParser()

    source = arg_parser.add_mutually_exclusive_group()
    source.add_argument("--text", dest="input_string", help="input string", type=str)
    source.add_argument("--input_file", dest="input_file", help="input file path", type=str)

    arg_parser.add_argument('--output_file', dest="output_file", help="output file path", type=str)
    arg_parser.add_argument("--language", help="language", choices=["en", "de", "es"], default="en", type=str)
    arg_parser.add_argument(
        "--input_case", help="input capitalization", choices=["lower_cased", "cased"], default="cased", type=str
    )

    # plain on/off switches, registered in the order they appear in the help output
    switches = (
        ("--verbose", "print info for debugging"),
        ("--punct_post_process", "set to True to enable punctuation post processing to match input."),
        ("--punct_pre_process", "set to True to enable punctuation pre processing"),
        ("--overwrite_cache", "set to True to re-create .far grammar files"),
    )
    for flag, description in switches:
        arg_parser.add_argument(flag, help=description, action="store_true")

    arg_parser.add_argument("--whitelist", help="path to a file with with whitelist", default=None, type=str)
    arg_parser.add_argument(
        "--cache_dir",
        help="path to a dir with .far grammar file. Set to None to avoid using cache",
        default=None,
        type=str,
    )
    return arg_parser.parse_args()
if __name__ == "__main__":
    # Command line entry point: normalizes either a single string (`--text`) or every line of `--input_file`.
    start_time = perf_counter()

    args = parse_args()
    if args.whitelist:
        whitelist = os.path.abspath(args.whitelist)
    else:
        whitelist = None

    # one of the two input sources has to be provided
    if not (args.input_string or args.input_file):
        raise ValueError("Either `--text` or `--input_file` required")

    normalizer = Normalizer(
        input_case=args.input_case,
        lang=args.language,
        cache_dir=args.cache_dir,
        overwrite_cache=args.overwrite_cache,
        whitelist=whitelist,
    )
    # options shared by the single string and the file mode
    normalize_options = dict(
        verbose=args.verbose,
        punct_pre_process=args.punct_pre_process,
        punct_post_process=args.punct_post_process,
    )

    if args.input_string:
        print(normalizer.normalize(args.input_string, **normalize_options))
    else:
        print("Loading data: " + args.input_file)
        data = load_file(args.input_file)

        print(f"- Data: {len(data)} sentences")
        normalizer_prediction = normalizer.normalize_list(data, **normalize_options)
        if not args.output_file:
            print(normalizer_prediction)
        else:
            write_file(args.output_file, normalizer_prediction)
            print(f"- Normalized. Writing out to {args.output_file}")

    print(f"Execution time: {perf_counter() - start_time:.02f} sec")

View File

@@ -0,0 +1,543 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import time
from argparse import ArgumentParser
from glob import glob
from typing import List, Tuple
import pynini
from joblib import Parallel, delayed
from nemo_text_processing.text_normalization.data_loader_utils import post_process_punct, pre_process
from nemo_text_processing.text_normalization.normalize import Normalizer
from pynini.lib import rewrite
from tqdm import tqdm
try:
from nemo.collections.asr.metrics.wer import word_error_rate
from nemo.collections.asr.models import ASRModel
ASR_AVAILABLE = True
except (ModuleNotFoundError, ImportError):
ASR_AVAILABLE = False
"""
The script provides multiple normalization options and chooses the best one that minimizes CER of the ASR output
(most of the semiotic classes use deterministic=False flag).
To run this script with a .json manifest file, the manifest file should contain the following fields:
"audio_data" - path to the audio file
"text" - raw text
"pred_text" - ASR model prediction
See https://github.com/NVIDIA/NeMo/blob/main/examples/asr/transcribe_speech.py on how to add ASR predictions
When the manifest is ready, run:
python normalize_with_audio.py \
--audio_data PATH/TO/MANIFEST.JSON \
--language en
To run with a single audio file, specify path to audio and text with:
python normalize_with_audio.py \
--audio_data PATH/TO/AUDIO.WAV \
--language en \
--text raw text OR PATH/TO/.TXT/FILE \
--model QuartzNet15x5Base-En \
--verbose
To see possible normalization options for a text input without an audio file (could be used for debugging), run:
python normalize_with_audio.py --text "RAW TEXT"
Specify `--cache_dir` to generate .far grammars once and re-use them for faster inference
"""
class NormalizerWithAudio(Normalizer):
    """
    Normalizer class that converts text from written to spoken form.
    Useful for TTS preprocessing.

    Unlike the parent class it always loads the non-deterministic grammars (``deterministic=False``) and
    returns several normalization options per input instead of a single string.

    Args:
        input_case: expected input capitalization
        lang: language
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
        whitelist: path to a file with whitelist replacements
        lm: set to True to use the LM tagger (English only); normalize() then also returns option weights
        post_process: WFST-based post processing, e.g. to remove extra spaces added during TN.
            Note: punct_post_process flag in normalize() supports all languages.
    """

    def __init__(
        self,
        input_case: str,
        lang: str = 'en',
        cache_dir: str = None,
        overwrite_cache: bool = False,
        whitelist: str = None,
        lm: bool = False,
        post_process: bool = True,
    ):
        super().__init__(
            input_case=input_case,
            lang=lang,
            deterministic=False,
            cache_dir=cache_dir,
            overwrite_cache=overwrite_cache,
            whitelist=whitelist,
            lm=lm,
            post_process=post_process,
        )
        self.lm = lm

    def normalize(self, text: str, n_tagged: int, punct_post_process: bool = True, verbose: bool = False,):
        """
        Main function. Normalizes tokens from written to spoken form
            e.g. 12 kg -> twelve kilograms

        Args:
            text: string that may include semiotic classes
            n_tagged: number of tagged options to consider, -1 - to get all possible tagged options
            punct_post_process: whether to normalize punctuation
            verbose: whether to print intermediate meta information

        Returns:
            normalized text options (usually there are multiple ways of normalizing a given semiotic class):
            a set of strings in the default mode, a tuple ``(list of options, weights)`` sorted by weight in
            LM mode, and the (empty) stripped input string if there is nothing to normalize.

        Raises:
            ValueError: if the input has more than 500 words, if LM mode is requested for a language other
                than English, or if no normalization option could be produced
        """
        if len(text.split()) > 500:
            raise ValueError(
                "Your input is too long. Please split up the input into sentences, "
                "or strings with fewer than 500 words"
            )

        original_text = text
        text = pre_process(text)  # to handle []

        text = text.strip()
        if not text:
            if verbose:
                print(text)
            return text
        text = pynini.escape(text)
        # debugging aid only: do not flood stdout when whole manifests are processed
        if verbose:
            print(text)

        if self.lm:
            if self.lang not in ["en"]:
                raise ValueError(f"{self.lang} is not supported in LM mode")

            # this to keep arpabet phonemes in the list of options
            if "[" in text and "]" in text:
                lattice = rewrite.rewrite_lattice(text, self.tagger.fst)
            else:
                try:
                    lattice = rewrite.rewrite_lattice(text, self.tagger.fst_no_digits)
                except pynini.lib.rewrite.Error:
                    lattice = rewrite.rewrite_lattice(text, self.tagger.fst)
            lattice = rewrite.lattice_to_nshortest(lattice, n_tagged)
            tagged_texts = [(x[1], float(x[2])) for x in lattice.paths().items()]
            tagged_texts.sort(key=lambda x: x[1])
            tagged_texts, weights = list(zip(*tagged_texts))
        else:
            tagged_texts = self._get_tagged_text(text, n_tagged)

        # non-deterministic Eng normalization uses tagger composed with verbalizer, no permutation in between
        if self.lang == "en":
            normalized_texts = [self.post_process(tagged_text) for tagged_text in tagged_texts]
        else:
            normalized_texts = []
            for tagged_text in tagged_texts:
                self._verbalize(tagged_text, normalized_texts, verbose=verbose)

        if len(normalized_texts) == 0:
            raise ValueError(f"No normalization options were produced for the input: {original_text}")

        if punct_post_process:
            # do post-processing based on Moses detokenizer
            if self.processor:
                normalized_texts = [self.processor.detokenize([t]) for t in normalized_texts]
                normalized_texts = [
                    post_process_punct(input=original_text, normalized_text=t) for t in normalized_texts
                ]

        if self.lm:
            # drop duplicated (option, weight) pairs and keep the options ordered by weight
            remove_dup = sorted(list(set(zip(normalized_texts, weights))), key=lambda x: x[1])
            normalized_texts, weights = zip(*remove_dup)
            return list(normalized_texts), weights

        normalized_texts = set(normalized_texts)
        return normalized_texts

    def _get_tagged_text(self, text, n_tagged):
        """
        Returns text after tokenize and classify

        Args:
            text: input text
            n_tagged: number of tagged options to consider, -1 - return all possible tagged options

        Returns: tagged options produced by the tagger grammar
        """
        if n_tagged == -1:

            def tag_with(fst):
                return rewrite.rewrites(text, fst)

        else:

            def tag_with(fst):
                return rewrite.top_rewrites(text, fst, nshortest=n_tagged)

        # "[" and "]" in the text: use the full graph to keep arpabet phonemes in the list of options
        if self.lang != "en" or ("[" in text and "]" in text):
            return tag_with(self.tagger.fst)

        try:
            # try self.tagger graph that produces output without digits
            return tag_with(self.tagger.fst_no_digits)
        except pynini.lib.rewrite.Error:
            return tag_with(self.tagger.fst)

    def _verbalize(self, tagged_text: str, normalized_texts: List[str], verbose: bool = False):
        """
        Verbalizes tagged text

        Args:
            tagged_text: text with tags
            normalized_texts: list of possible normalization options, extended in place
            verbose: if true prints intermediate classification results
        """
        self.parser(tagged_text)
        tokens = self.parser.parse()
        tags_reordered = self.generate_permutations(tokens)
        for tagged_text_reordered in tags_reordered:
            try:
                tagged_text_reordered = pynini.escape(tagged_text_reordered)
                normalized_texts.extend(rewrite.rewrites(tagged_text_reordered, self.verbalizer.fst))
                if verbose:
                    print(tagged_text_reordered)
            except pynini.lib.rewrite.Error:
                # this reordering is not accepted by the verbalizer, try the next one
                continue

    def select_best_match(
        self,
        normalized_texts: List[str],
        input_text: str,
        pred_text: str,
        verbose: bool = False,
        remove_punct: bool = False,
        cer_threshold: int = 100,
    ):
        """
        Selects the best normalization option based on the lowest CER

        Args:
            normalized_texts: normalized text options
            input_text: input text
            pred_text: ASR model transcript of the audio file corresponding to the normalized text
            verbose: whether to print intermediate meta information
            remove_punct: whether to remove punctuation before calculating CER
            cer_threshold: if CER for pred_text is above the cer_threshold, no normalization will be performed

        Returns:
            normalized text with the lowest CER and CER value; the input text is returned unchanged when
            there is no transcript, no normalization option, or the best CER is above ``cer_threshold``
        """
        # nothing to compare: keep the input as is instead of failing on an empty candidate list
        if pred_text == "" or not normalized_texts:
            return input_text, cer_threshold

        normalized_texts_cer = calculate_cer(normalized_texts, pred_text, remove_punct)
        normalized_texts_cer = sorted(normalized_texts_cer, key=lambda x: x[1])
        normalized_text, cer = normalized_texts_cer[0]

        if cer > cer_threshold:
            return input_text, cer

        if verbose:
            print('-' * 30)
            for option in normalized_texts:
                print(option)
            print('-' * 30)
        return normalized_text, cer
def calculate_cer(normalized_texts: List[str], pred_text: str, remove_punct=False) -> List[Tuple[str, float]]:
    """
    Calculates character error rate (CER) of every normalization option against the ASR output.

    Before scoring, each option is lower cased and its hyphens are replaced with spaces; the option itself
    is returned unmodified.

    Args:
        normalized_texts: normalized text options
        pred_text: ASR model output
        remove_punct: whether to strip punctuation marks from an option before scoring it

    Returns: normalized options with corresponding CER (in percent, rounded to 2 digits)
    """

    def _score(option: str) -> float:
        cleaned = option.replace('-', ' ').lower()
        if remove_punct:
            for mark in "!?:;,.-()*+-/<=>@^_":
                cleaned = cleaned.replace(mark, "")
        return round(word_error_rate([pred_text], [cleaned], use_cer=True) * 100, 2)

    return [(option, _score(option)) for option in normalized_texts]
def get_asr_model(asr_model):
    """
    Returns ASR Model

    Args:
        asr_model: path to a local model checkpoint or name of a pretrained NeMo ASR model

    Returns: ASR model restored from the checkpoint or loaded by its pretrained name

    Raises:
        ValueError: if ``asr_model`` is neither an existing path nor an available pretrained model name
    """
    # rely on the argument only, not on the `args` global of the command line script
    if os.path.exists(asr_model):
        return ASRModel.restore_from(asr_model)
    if asr_model in ASRModel.get_available_model_names():
        return ASRModel.from_pretrained(asr_model)
    raise ValueError(
        f'Provide path to the pretrained checkpoint or choose from {ASRModel.get_available_model_names()}'
    )
def parse_args():
    """
    Declares the command line options of the audio-based normalization script and parses the command line.

    Returns: namespace with the parsed options
    """
    arg_parser = ArgumentParser()
    add = arg_parser.add_argument

    # input: raw text and/or audio
    add("--text", help="input string or path to a .txt file", default=None, type=str)
    add("--input_case", help="input capitalization", choices=["lower_cased", "cased"], default="cased", type=str)
    add("--language", help="Select target language", choices=["en", "ru", "de", "es"], default="en", type=str)
    add("--audio_data", default=None, help="path to an audio file or .json manifest")
    add(
        '--model', type=str, default='QuartzNet15x5Base-En', help='Pre-trained model name or path to model checkpoint'
    )

    # normalization behaviour
    add(
        "--n_tagged",
        type=int,
        default=30,
        help="number of tagged options to consider, -1 - return all possible tagged options",
    )
    add("--verbose", help="print info for debugging", action="store_true")
    add(
        "--no_remove_punct_for_cer",
        help="Set to True to NOT remove punctuation before calculating CER",
        action="store_true",
    )
    add("--no_punct_post_process", help="set to True to disable punctuation post processing", action="store_true")

    # grammars
    add("--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true")
    add("--whitelist", help="path to a file with with whitelist", default=None, type=str)
    add(
        "--cache_dir",
        help="path to a dir with .far grammar file. Set to None to avoid using cache",
        default=None,
        type=str,
    )

    # manifest processing
    add("--n_jobs", default=-2, type=int, help="The maximum number of concurrently running jobs")
    add("--lm", action="store_true", help="Set to True for WFST+LM. Only available for English right now.")
    add(
        "--cer_threshold",
        default=100,
        type=int,
        help="if CER for pred_text is above the cer_threshold, no normalization will be performed",
    )
    add("--batch_size", default=200, type=int, help="Number of examples for each process")

    return arg_parser.parse_args()
def _normalize_line(
    normalizer: NormalizerWithAudio, n_tagged, verbose, line: str, remove_punct, punct_post_process, cer_threshold
):
    """
    Normalizes a single manifest entry and picks the option closest to the ASR transcript.

    Args:
        normalizer: normalizer that produces the options and selects the best match
        n_tagged: number of tagged options to consider
        verbose: whether to print intermediate meta information
        line: JSON string with the fields "text" and "pred_text"
        remove_punct: whether to remove punctuation before calculating CER
        punct_post_process: whether to normalize punctuation
        cer_threshold: CER value above which the input text is kept unchanged

    Returns: parsed manifest entry extended with "nemo_normalized" and "CER_nemo_normalized"
    """
    entry = json.loads(line)
    transcript = entry["pred_text"]
    raw_text = entry["text"]

    options = set(
        normalizer.normalize(
            text=raw_text, verbose=verbose, n_tagged=n_tagged, punct_post_process=punct_post_process,
        )
    )
    best_option, best_cer = normalizer.select_best_match(
        normalized_texts=options,
        input_text=raw_text,
        pred_text=transcript,
        verbose=verbose,
        remove_punct=remove_punct,
        cer_threshold=cer_threshold,
    )

    entry["nemo_normalized"] = best_option
    entry["CER_nemo_normalized"] = best_cer
    return entry
def normalize_manifest(
    normalizer,
    audio_data: str,
    n_jobs: int,
    n_tagged: int,
    remove_punct: bool,
    punct_post_process: bool,
    batch_size: int,
    cer_threshold: int,
):
    """
    Normalizes every entry of a manifest in parallel batches and writes the result next to the input as
    ``<manifest name>_normalized<extension>``. Intermediate per-batch files are kept in
    ``<manifest name>_normalized_parts``.

    Args:
        normalizer: normalizer passed to ``_normalize_line`` for every entry
        audio_data: path to .json manifest file.
        n_jobs: the maximum number of concurrently running jobs
        n_tagged: number of tagged options to consider
        remove_punct: whether to remove punctuation before calculating CER
        punct_post_process: whether to normalize punctuation
        batch_size: number of manifest lines per job
        cer_threshold: CER value above which the input text is kept unchanged
    """

    def __process_batch(batch_idx: int, batch: List[str], dir_name: str):
        """
        Normalizes batch of text sequences

        Args:
            batch: list of texts
            batch_idx: batch index
            dir_name: path to output directory to save results
        """
        normalized_lines = [
            _normalize_line(
                normalizer,
                n_tagged,
                verbose=False,
                line=line,
                remove_punct=remove_punct,
                punct_post_process=punct_post_process,
                cer_threshold=cer_threshold,
            )
            for line in tqdm(batch)
        ]

        # entries are dumped with ensure_ascii=False, so the encoding has to be fixed explicitly
        with open(f"{dir_name}/{batch_idx:05}.json", "w", encoding="utf-8") as f_out:
            for line in normalized_lines:
                f_out.write(json.dumps(line, ensure_ascii=False) + '\n')

        print(f"Batch -- {batch_idx} -- is complete")

    # only touch the file extension: a '.json' inside a directory name must stay as it is
    root, ext = os.path.splitext(audio_data)
    manifest_out = f"{root}_normalized{ext}"
    with open(audio_data, 'r', encoding="utf-8") as f:
        lines = f.readlines()

    print(f'Normalizing {len(lines)} lines of {audio_data}...')

    if not lines:
        # an empty manifest would lead to a zero batch size; the result is simply an empty manifest
        with open(manifest_out, "w", encoding="utf-8"):
            pass
        print(f'Normalized version saved at {manifest_out}')
        return

    # to save intermediate results to a file
    batch = min(len(lines), batch_size)

    tmp_dir = f"{root}_normalized_parts"
    os.makedirs(tmp_dir, exist_ok=True)

    batch_starts = range(0, len(lines), batch)
    Parallel(n_jobs=n_jobs)(
        delayed(__process_batch)(idx, lines[i : i + batch], tmp_dir) for idx, i in enumerate(batch_starts)
    )

    # aggregate the intermediate files of this run only, in batch order; the parts directory may still
    # contain files of an earlier run with more batches
    with open(manifest_out, "w", encoding="utf-8") as f_out:
        for idx in range(len(batch_starts)):
            with open(f"{tmp_dir}/{idx:05}.json", "r", encoding="utf-8") as f_in:
                f_out.write(f_in.read())

    print(f'Normalized version saved at {manifest_out}')
if __name__ == "__main__":
    # Command line entry point. Three modes are supported:
    #   * `--text` only: print all normalization options (debugging without audio)
    #   * `--text` and `--audio_data` (audio file): pick the option closest to the ASR transcript
    #   * `--audio_data` (.json manifest): normalize every manifest entry
    args = parse_args()

    if not ASR_AVAILABLE and args.audio_data:
        raise ValueError("NeMo ASR collection is not installed.")

    usage_error = (
        "Provide either path to .json manifest in '--audio_data' OR "
        "'--audio_data' path to audio file and '--text' path to a text file OR "
        "'--text' string text (for debugging without audio)"
    )

    start = time.time()
    args.whitelist = os.path.abspath(args.whitelist) if args.whitelist else None
    if args.text is not None:
        normalizer = NormalizerWithAudio(
            input_case=args.input_case,
            lang=args.language,
            cache_dir=args.cache_dir,
            overwrite_cache=args.overwrite_cache,
            whitelist=args.whitelist,
            lm=args.lm,
        )

        # `--text` is either the raw text or a path to a file that contains it
        if os.path.exists(args.text):
            with open(args.text, 'r') as f:
                args.text = f.read().strip()
        normalized_texts = normalizer.normalize(
            text=args.text,
            verbose=args.verbose,
            n_tagged=args.n_tagged,
            punct_post_process=not args.no_punct_post_process,
        )
        if not normalizer.lm:
            normalized_texts = set(normalized_texts)
        if args.audio_data:
            # in LM mode normalize() returns (options, weights); CER selection works on the options only
            options = normalized_texts[0] if isinstance(normalized_texts, tuple) else normalized_texts
            asr_model = get_asr_model(args.model)
            pred_text = asr_model.transcribe([args.audio_data])[0]
            normalized_text, cer = normalizer.select_best_match(
                normalized_texts=options,
                pred_text=pred_text,
                input_text=args.text,
                verbose=args.verbose,
                remove_punct=not args.no_remove_punct_for_cer,
                cer_threshold=args.cer_threshold,
            )
            print(f"Transcript: {pred_text}")
            print(f"Normalized: {normalized_text}")
        else:
            print("Normalization options:")
            for norm_text in normalized_texts:
                print(norm_text)
    elif args.audio_data is None:
        # neither `--text` nor `--audio_data`: report the usage instead of failing on a `None` path
        raise ValueError(usage_error)
    elif not os.path.exists(args.audio_data):
        raise ValueError(f"{args.audio_data} not found.")
    elif args.audio_data.endswith('.json'):
        normalizer = NormalizerWithAudio(
            input_case=args.input_case,
            lang=args.language,
            cache_dir=args.cache_dir,
            overwrite_cache=args.overwrite_cache,
            whitelist=args.whitelist,
        )
        normalize_manifest(
            normalizer=normalizer,
            audio_data=args.audio_data,
            n_jobs=args.n_jobs,
            n_tagged=args.n_tagged,
            remove_punct=not args.no_remove_punct_for_cer,
            punct_post_process=not args.no_punct_post_process,
            batch_size=args.batch_size,
            cer_threshold=args.cer_threshold,
        )
    else:
        raise ValueError(usage_error)
    print(f'Execution time: {round((time.time() - start)/60, 2)} min.')

View File

@@ -0,0 +1,117 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import ArgumentParser
from nemo_text_processing.text_normalization.data_loader_utils import (
evaluate,
known_types,
load_files,
training_data_to_sentences,
training_data_to_tokens,
)
from nemo_text_processing.text_normalization.normalize import Normalizer
'''
Runs Evaluation on data in the format of : <semiotic class>\t<unnormalized text>\t<`self` if trivial class or normalized text>
like the Google text normalization data https://www.kaggle.com/richardwilliamsproat/text-normalization-for-english-russian-and-polish
'''
def parse_args():
    """
    Declares the command line options of the evaluation script and parses the command line.

    Returns: namespace with the parsed options
    """
    arg_parser = ArgumentParser()
    add = arg_parser.add_argument

    add("--input", help="input file path", type=str)
    add("--lang", help="language", choices=['en'], default="en", type=str)
    add("--input_case", help="input capitalization", choices=["lower_cased", "cased"], default="cased", type=str)

    # restricts the token level evaluation to a single semiotic class
    category_help = "focus on class only (" + ", ".join(known_types) + ")"
    add("--cat", dest="category", help=category_help, type=str, default=None, choices=known_types)

    add("--filter", action='store_true', help="clean data for normalization purposes")
    return arg_parser.parse_args()
if __name__ == "__main__":
    # Example usage:
    # python run_evaluate.py --input=<INPUT> --cat=<CATEGORY> --filter
    #
    # Evaluates the normalizer on sentence level (only when no category is selected) and on token level,
    # then prints a summary per semiotic class.
    args = parse_args()
    if args.lang == 'en':
        from nemo_text_processing.text_normalization.en.clean_eval_data import filter_loaded_data
    file_path = args.input
    normalizer = Normalizer(input_case=args.input_case, lang=args.lang)

    print("Loading training data: " + file_path)
    training_data = load_files([file_path])

    if args.filter:
        training_data = filter_loaded_data(training_data)

    if args.category is None:
        print("Sentence level evaluation...")
        sentences_un_normalized, sentences_normalized, _ = training_data_to_sentences(training_data)
        print("- Data: " + str(len(sentences_normalized)) + " sentences")
        sentences_prediction = normalizer.normalize_list(sentences_un_normalized)
        print("- Normalized. Evaluating...")
        sentences_accuracy = evaluate(
            preds=sentences_prediction, labels=sentences_normalized, input=sentences_un_normalized
        )
        print("- Accuracy: " + str(sentences_accuracy))

    print("Token level evaluation...")
    tokens_per_type = training_data_to_tokens(training_data, category=args.category)
    token_accuracy = {}
    for token_type in tokens_per_type:
        print("- Token type: " + token_type)
        tokens_un_normalized, tokens_normalized = tokens_per_type[token_type]
        print(" - Data: " + str(len(tokens_normalized)) + " tokens")
        tokens_prediction = normalizer.normalize_list(tokens_un_normalized)
        print(" - Normalized. Evaluating...")
        token_accuracy[token_type] = evaluate(
            preds=tokens_prediction, labels=tokens_normalized, input=tokens_un_normalized
        )
        print(" - Accuracy: " + str(token_accuracy[token_type]))

    # overall token accuracy: per class accuracy weighted by the number of tokens of the class
    token_count_per_type = {token_type: len(tokens_per_type[token_type][0]) for token_type in tokens_per_type}
    token_weighted_accuracy = [
        token_count_per_type[token_type] * accuracy for token_type, accuracy in token_accuracy.items()
    ]
    total_tokens = sum(token_count_per_type.values())
    if total_tokens:
        print("- Accuracy: " + str(sum(token_weighted_accuracy) / total_tokens))
    else:
        # nothing was loaded for the token level evaluation; avoid dividing by zero
        print("- Accuracy: n/a (no tokens to evaluate)")
    print(" - Total: " + str(total_tokens), '\n')

    for token_type in token_accuracy:
        if token_type not in known_types:
            raise ValueError("Unexpected token type: " + token_type)

    if args.category is None:
        # summary table: one row per class, preceded by a header and the sentence level result
        c1 = ['Class', 'sent level'] + known_types
        c2 = ['Num Tokens', len(sentences_normalized)] + [
            token_count_per_type[known_type] if known_type in tokens_per_type else '0' for known_type in known_types
        ]
        c3 = ['Normalization', sentences_accuracy] + [
            token_accuracy[known_type] if known_type in token_accuracy else '0' for known_type in known_types
        ]

        for class_name, num_tokens, accuracy in zip(c1, c2, c3):
            print(f'{str(class_name):10s} | {str(num_tokens):10s} | {str(accuracy):5s}')
    else:
        print(f'numbers\t{token_count_per_type[args.category]}')
        print(f'Normalization\t{token_accuracy[args.category]}')

View File

@@ -0,0 +1,192 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import string
from collections import OrderedDict
from typing import Dict, List, Union
PRESERVE_ORDER_KEY = "preserve_order"
# Sentinel stored in ``TokenParser.char`` once the input is exhausted. It is longer than
# one character, so it can never be equal to a real character of the text.
EOS = "<EOS>"


class _TokenParseError(ValueError, AssertionError):
    """
    Raised by TokenParser when the text does not follow the expected grammar.

    Derives from both ValueError and AssertionError: older versions of the parser signalled
    malformed input with a mix of bare ``assert`` statements and ``ValueError``, so callers
    catching either type keep working.
    """


class TokenParser:
    """
    Parses tokenized/classified text, e.g.
    'tokens { money { integer: "20" currency: "$" } } tokens { name: "left" }'

    Usage: call the instance with the text to load it, then call ``parse()``.
    A string value ends at a double quote that is followed by a space or by the end of the text,
    so e.g. 'name: "left"}' (no space before the brace) is not accepted.

    Args
        text: tokenized text
    """

    # Characters allowed in a key; built once instead of on every parse_string_key() call.
    _KEY_CHARS = frozenset(string.ascii_letters + "_")

    def __call__(self, text):
        """
        Setup function. Resets the parser state so that it points at the first character of text.

        Args:
            text: text to be parsed
        """
        self.text = text
        self.len_text = len(text)
        self.index = 0
        # Empty text starts directly at EOS, so parse() returns [] instead of raising IndexError.
        self.char = text[0] if text else EOS

    def parse(self) -> List[dict]:
        """
        Main function. Implements grammar:
        A -> space F space F space F ... space

        Stops at the end of the text or at the first position where no key can be read
        (e.g. a closing brace), leaving that character for the caller.

        Returns list of dictionaries
        """
        tokens = []
        while self.parse_ws():
            token = self.parse_token()
            if not token:
                break
            tokens.append(token)
        return tokens

    def parse_token(self) -> Union[Dict[str, Union[str, dict]], None]:
        """
        Implements grammar:
        F-> no_space KG no_space

        Returns: K, G as dictionary values, or None if no key starts at the current position

        Raises:
            ValueError: if the text is malformed
        """
        key = self.parse_string_key()
        if key is None:
            return None
        self.parse_ws()
        if key == PRESERVE_ORDER_KEY:
            # preserve_order is the only key with an unquoted value: the literal `true`
            self.parse_char(":")
            self.parse_ws()
            value = self.parse_chars("true")
        else:
            value = self.parse_token_value()
        token = OrderedDict()
        token[key] = value
        return token

    def parse_token_value(self) -> Union[str, dict]:
        """
        Implements grammar:
        G-> no_space :"VALUE" no_space | no_space {A} no_space

        Returns: string or dictionary (None for an empty string value, see parse_string_value)

        Raises:
            ValueError: if the current character is neither ':' nor '{', or the value is malformed
        """
        if self.char == ":":
            self.parse_char(":")
            self.parse_ws()
            self.parse_char("\"")
            value_string = self.parse_string_value()
            self.parse_char("\"")
            return value_string
        if self.char == "{":
            self.parse_char("{")
            # flatten tokens: merge the single-key dictionaries into one, keeping their order
            flattened = OrderedDict()
            for tok_dict in self.parse():
                flattened.update(tok_dict)
            self.parse_char("}")
            return flattened
        raise _TokenParseError(f"expected ':' or '{{' near position {self.index}, found {self.char!r}")

    def parse_char(self, exp) -> bool:
        """
        Parses character

        Args:
            exp: character to read in

        Returns true if successful

        Raises:
            ValueError: if the current character is not exp
        """
        # Explicit raise instead of `assert`, which would be stripped when running with -O.
        if self.char != exp:
            raise _TokenParseError(f"expected {exp!r} near position {self.index}, found {self.char!r}")
        self.read()
        return True

    def parse_chars(self, exp) -> bool:
        """
        Parses characters

        Args:
            exp: characters to read in

        Returns true if successful (false only if exp is empty)

        Raises:
            ValueError: if the text does not continue with exp
        """
        ok = False
        for x in exp:
            ok |= self.parse_char(x)
        return ok

    def parse_string_key(self) -> Union[str, None]:
        """
        Parses string key, can only contain ascii and '_' characters

        Returns parsed string key, or None if the current character cannot start a key

        Raises:
            ValueError: if positioned on whitespace or EOS, or if the text ends inside the key
        """
        self._require_non_blank("a key")
        chars = []
        while self.char in self._KEY_CHARS:
            chars.append(self.char)
            if not self.read():
                raise _TokenParseError("unexpected end of text while reading a key")
        if not chars:
            return None
        return "".join(chars)

    def parse_string_value(self) -> Union[str, None]:
        """
        Parses string value, ends with quote followed by space or by the end of the text.
        The closing quote itself is not consumed.

        Returns parsed string value, or None if the value is empty

        Raises:
            ValueError: if positioned on whitespace or EOS, or if no closing quote is found
        """
        self._require_non_blank("a string value")
        chars = []
        while not self._at_closing_quote():
            chars.append(self.char)
            if not self.read():
                raise _TokenParseError("unexpected end of text while reading a string value")
        if not chars:
            return None
        return "".join(chars)

    def parse_ws(self):
        """
        Deletes whitespaces (space characters only).

        Returns true if not EOS after parsing
        """
        not_eos = self.char != EOS
        while not_eos and self.char == " ":
            not_eos = self.read()
        return not_eos

    def read(self):
        """
        Reads in next char. Sets the current character to EOS when the text is exhausted.

        Returns true if not EOS
        """
        if self.index < self.len_text - 1:  # should be unique
            self.index += 1
            self.char = self.text[self.index]
            return True
        self.char = EOS
        return False

    def _at_closing_quote(self) -> bool:
        """Returns true if the current character is a quote followed by a space or the end of the text."""
        if self.char != "\"":
            return False
        following = self.index + 1
        # Bounds check first: a quote on the very last character used to raise IndexError.
        return following >= self.len_text or self.text[following] == " "

    def _require_non_blank(self, what) -> None:
        """Raises ValueError if the parser is positioned on whitespace or EOS, where `what` cannot start."""
        if self.char == EOS or self.char in string.whitespace:
            raise _TokenParseError(f"expected {what} near position {self.index}, found {self.char!r}")