#!/usr/bin/env python3
# coding=utf8
# Copyright 2022 Zhenxiang MA, Jiayu DU (SpeechColab)

import argparse
import csv
import json
import logging
import os
import sys
from typing import Iterable

logging.basicConfig(stream=sys.stderr, level=logging.ERROR, format='[%(levelname)s] %(message)s')

import pynini
from pynini.lib import pynutil


# reference: https://github.com/kylebgorman/pynini/blob/master/pynini/lib/edit_transducer.py
# to import the original lib:
#   from pynini.lib.edit_transducer import EditTransducer
class EditTransducer:
    DELETE = "<delete>"
    INSERT = "<insert>"
    SUBSTITUTE = "<substitute>"

    def __init__(
        self,
        symbol_table,
        vocab: Iterable[str],
        insert_cost: float = 1.0,
        delete_cost: float = 1.0,
        substitute_cost: float = 1.0,
        bound: int = 0,
    ):
        # Left factor; note that we divide the edit costs by two because they also
        # will be incurred when traversing the right factor.
        sigma = pynini.union(
            *[pynini.accep(token, token_type=symbol_table) for token in vocab],
        ).optimize()

        insert = pynutil.insert(f"[{self.INSERT}]", weight=insert_cost / 2)
        delete = pynini.cross(sigma, pynini.accep(f"[{self.DELETE}]", weight=delete_cost / 2))
        substitute = pynini.cross(sigma, pynini.accep(f"[{self.SUBSTITUTE}]", weight=substitute_cost / 2))

        edit = pynini.union(insert, delete, substitute).optimize()

        if bound:
            sigma_star = pynini.closure(sigma)
            self._e_i = sigma_star.copy()
            for _ in range(bound):
                self._e_i.concat(edit.ques).concat(sigma_star)
        else:
            self._e_i = edit.union(sigma).closure()

        self._e_i.optimize()

        right_factor_std = EditTransducer._right_factor(self._e_i)
        # right_factor_ext allows 0-cost matching between a token's raw form & auxiliary form
        # e.g.: 'I' -> 'I#', 'AM' -> 'AM#'
        right_factor_ext = (
            pynini.union(
                *[
                    pynini.cross(
                        pynini.accep(x, token_type=symbol_table),
                        pynini.accep(x + '#', token_type=symbol_table),
                    )
                    for x in vocab
                ]
            )
            .optimize()
            .closure()
        )
        self._e_o = pynini.union(right_factor_std, right_factor_ext).closure().optimize()

    @staticmethod
    def _right_factor(ifst: pynini.Fst) -> pynini.Fst:
        ofst = pynini.invert(ifst)
        syms = pynini.generated_symbols()
        insert_label = syms.find(EditTransducer.INSERT)
        delete_label = syms.find(EditTransducer.DELETE)
        pairs = [(insert_label, delete_label), (delete_label, insert_label)]
        right_factor = ofst.relabel_pairs(ipairs=pairs)
        return right_factor

    def create_lattice(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> pynini.Fst:
        lattice = (iexpr @ self._e_i) @ (self._e_o @ oexpr)
        EditTransducer.check_wellformed_lattice(lattice)
        return lattice

    @staticmethod
    def check_wellformed_lattice(lattice: pynini.Fst) -> None:
        if lattice.start() == pynini.NO_STATE_ID:
            raise RuntimeError("Edit distance composition lattice is empty.")

    def compute_distance(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> float:
        lattice = self.create_lattice(iexpr, oexpr)
        # The shortest cost from all final states to the start state is
        # equivalent to the cost of the shortest path.
        start = lattice.start()
        return float(pynini.shortestdistance(lattice, reverse=True)[start])

    def compute_alignment(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> pynini.FstLike:
        lattice = self.create_lattice(iexpr, oexpr)
        alignment = pynini.shortestpath(lattice, nshortest=1, unique=True)
        return alignment.optimize()


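# --- Illustrative usage sketch (added for exposition; never called by this
# script). Assuming a toy 4-word vocabulary, it shows how EditTransducer is
# driven: build a symbol table containing the auxiliary '#' forms the class
# expects, then score a hypothesis against a reference.
def _EditTransducerDemo():
    vocab = ['HEY', 'I', 'AM', 'HERE']
    symtab = pynini.SymbolTable()
    symtab.add_symbol('<epsilon>')
    for w in vocab + [w + '#' for w in vocab]:
        symtab.add_symbol(w)
    et = EditTransducer(symbol_table=symtab, vocab=vocab)
    ref = pynini.accep('HEY I AM HERE', token_type=symtab)
    hyp = pynini.accep('HEY I HERE', token_type=symtab)
    return et.compute_distance(ref, hyp)  # one deletion -> expected 1.0

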
class ErrorStats:
    def __init__(self):
        self.num_ref_utts = 0
        self.num_hyp_utts = 0
        self.num_eval_utts = 0  # in both ref & hyp
        self.num_hyp_without_ref = 0

        self.C = 0
        self.S = 0
        self.I = 0
        self.D = 0
        self.token_error_rate = 0.0
        self.modified_token_error_rate = 0.0

        self.num_utts_with_error = 0
        self.sentence_error_rate = 0.0

    def to_json(self):
        # return json.dumps(self.__dict__, indent=4)
        return json.dumps(self.__dict__)

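    # Example to_kaldi() output (illustrative numbers only, Kaldi-style lines):
    #   %WER 40.00 [ 4 / 10, 1 ins, 2 del, 1 sub ]
    #   %SER 50.00 [ 1 / 2 ]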
    def to_kaldi(self):
        info = (
            F'%WER {self.token_error_rate:.2f} [ {self.S + self.D + self.I} / {self.C + self.S + self.D}, {self.I} ins, {self.D} del, {self.S} sub ]\n'
            F'%SER {self.sentence_error_rate:.2f} [ {self.num_utts_with_error} / {self.num_eval_utts} ]\n'
        )
        return info

    def to_summary(self):
        summary = (
            '==================== Overall Statistics ====================\n'
            F'num_ref_utts: {self.num_ref_utts}\n'
            F'num_hyp_utts: {self.num_hyp_utts}\n'
            F'num_hyp_without_ref: {self.num_hyp_without_ref}\n'
            F'num_eval_utts: {self.num_eval_utts}\n'
            F'sentence_error_rate: {self.sentence_error_rate:.2f}%\n'
            F'token_error_rate: {self.token_error_rate:.2f}%\n'
            F'modified_token_error_rate: {self.modified_token_error_rate:.2f}%\n'
            F'token_stats:\n'
            F' - tokens:{self.C + self.S + self.D:>7}\n'
            F' - edits: {self.S + self.I + self.D:>7}\n'
            F' - cor:   {self.C:>7}\n'
            F' - sub:   {self.S:>7}\n'
            F' - ins:   {self.I:>7}\n'
            F' - del:   {self.D:>7}\n'
            '============================================================\n'
        )
        return summary


class Utterance:
    def __init__(self, uid, text):
        self.uid = uid
        self.text = text


def LoadKaldiArc(filepath):
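    # Kaldi text-archive format: one utterance per line, "uid [transcript]",
    # where the transcript may be empty. Illustrative lines (hypothetical uids):
    #   utt_001 HEY I AM HERE
    #   utt_002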
    utts = {}
    with open(filepath, 'r', encoding='utf8') as f:
        for line in f:
            line = line.strip()
            if line:
                cols = line.split(maxsplit=1)
                assert len(cols) == 2 or len(cols) == 1
                uid = cols[0]
                text = cols[1] if len(cols) == 2 else ''
                if uid in utts:
                    raise RuntimeError(F'Found duplicated utterance id {uid}')
                utts[uid] = Utterance(uid, text)
    return utts


def BreakHyphen(token: str):
    # 'T-SHIRT' should also introduce new words into vocabulary, e.g.:
    #   1. 'T' & 'SHIRT'
    #   2. 'TSHIRT'
    assert '-' in token
    v = token.split('-')
    v.append(token.replace('-', ''))
    return v


def LoadGLM(rel_path):
    '''
    glm.csv:
        I'VE,I HAVE
        GOING TO,GONNA
        ...
        T-SHIRT,T SHIRT,TSHIRT

    glm:
        {
            '<RULE_000000>': ["I'VE", 'I HAVE'],
            '<RULE_000001>': ['GOING TO', 'GONNA'],
            ...
            '<RULE_999999>': ['T-SHIRT', 'T SHIRT', 'TSHIRT'],
        }
    '''
    logging.info(f'Loading GLM from {rel_path} ...')

    abs_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), rel_path)
    with open(abs_path, encoding='utf-8') as f:
        reader = list(csv.reader(f, delimiter=','))

    glm = {}
    for k, rule in enumerate(reader):
        rule_name = f'<RULE_{k:06d}>'
        glm[rule_name] = [phrase.strip() for phrase in rule]
    logging.info(f' #rule: {len(glm)}')

    return glm


def SymbolEQ(symbol_table, i1, i2):
    return symbol_table.find(i1).strip('#') == symbol_table.find(i2).strip('#')


def PrintSymbolTable(symbol_table: pynini.SymbolTable):
    print('SYMBOL_TABLE:')
    for k in range(symbol_table.num_symbols()):
        sym = symbol_table.find(k)
        assert symbol_table.find(sym) == k  # symbol table's find can be used for bi-directional lookup (id <-> sym)
        print(k, sym)
    print()


def BuildSymbolTable(vocab) -> pynini.SymbolTable:
    logging.info('Building symbol table ...')
    symbol_table = pynini.SymbolTable()
    symbol_table.add_symbol('<epsilon>')

    for w in vocab:
        symbol_table.add_symbol(w)
    logging.info(f' #symbols: {symbol_table.num_symbols()}')

    # PrintSymbolTable(symbol_table)
    # symbol_table.write_text('symbol_table.txt')
    return symbol_table


def BuildGLMTagger(glm, symbol_table) -> pynini.Fst:
    logging.info('Building GLM tagger ...')
    rule_taggers = []
    for rule_tag, rule in glm.items():
        for phrase in rule:
            rule_taggers.append(
                (
                    pynutil.insert(pynini.accep(rule_tag, token_type=symbol_table))
                    + pynini.accep(phrase, token_type=symbol_table)
                    + pynutil.insert(pynini.accep(rule_tag, token_type=symbol_table))
                )
            )

    alphabet = pynini.union(
        *[pynini.accep(sym, token_type=symbol_table) for k, sym in symbol_table if k != 0]  # non-epsilon
    ).optimize()

    tagger = pynini.cdrewrite(
        pynini.union(*rule_taggers).optimize(), '', '', alphabet.closure()
    ).optimize()  # could be slow with large vocabulary
    return tagger


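# --- Illustrative usage sketch (added for exposition; never called by this
# script). Assuming a one-rule GLM, it mirrors step 1 of the main loop below:
# compose a hypothesis with the tagger and read off the tagged string.
def _GLMTaggerDemo():
    glm = {'<RULE_000000>': ["I'M", 'I AM']}
    vocab = ['HEY', "I'M", 'I', 'AM', 'HERE']
    symtab = BuildSymbolTable(vocab + [w + '#' for w in vocab] + list(glm.keys()))
    tagger = BuildGLMTagger(glm, symtab)
    lattice = (pynini.accep("HEY I'M HERE", token_type=symtab) @ tagger).optimize()
    # Expected: "HEY <RULE_000000> I'M <RULE_000000> HERE"
    return pynini.shortestpath(lattice, nshortest=1, unique=True).string(token_type=symtab)

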
def TokenWidth(token: str):
    def CharWidth(c):
        return 2 if (c >= '\u4e00') and (c <= '\u9fa5') else 1

    return sum([CharWidth(c) for c in token])
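

# Worked example: characters in U+4E00..U+9FA5 count as width 2, everything
# else as width 1, so TokenWidth('AB') == 2 and TokenWidth('T恤') == 3.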


def PrintPrettyAlignment(raw_hyp, edit_ali, ref_ali, hyp_ali, stream=sys.stderr):
    assert len(edit_ali) == len(ref_ali) and len(ref_ali) == len(hyp_ali)

    H = ' HYP# : '
    R = ' REF  : '
    E = ' EDIT : '
    for i, e in enumerate(edit_ali):
        h, r = hyp_ali[i], ref_ali[i]
        e = '' if e == 'C' else e  # don't bother printing correct edit-tag

        nr, nh, ne = TokenWidth(r), TokenWidth(h), TokenWidth(e)
        n = max(nr, nh, ne) + 1

        H += h + ' ' * (n - nh)
        R += r + ' ' * (n - nr)
        E += e + ' ' * (n - ne)

    print(F' HYP  : {raw_hyp}', file=stream)
    print(H, file=stream)
    print(R, file=stream)
    print(E, file=stream)


def ComputeTokenErrorRate(c, s, i, d):
    assert (s + d + c) != 0
    num_edits = s + d + i
    ref_len = c + s + d
    hyp_len = c + s + i
    return 100.0 * num_edits / ref_len, 100.0 * num_edits / max(ref_len, hyp_len)
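

# Worked example: c=7, s=1, i=2, d=1 gives ref_len = 9, hyp_len = 10 and
# num_edits = 4, so TER = 100*4/9 ~= 44.44 while the modified TER
# (normalized by the longer of ref/hyp) = 100*4/10 = 40.00.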


def ComputeSentenceErrorRate(num_err_utts, num_utts):
    assert num_utts != 0
    return 100.0 * num_err_utts / num_utts


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--logk', type=int, default=500, help='logging interval')
    parser.add_argument(
        '--tokenizer', choices=['whitespace', 'char'], default='whitespace', help='whitespace for WER, char for CER'
    )
    parser.add_argument('--glm', type=str, default='glm_en.csv', help='glm')
    parser.add_argument('--ref', type=str, required=True, help='reference kaldi arc file')
    parser.add_argument('--hyp', type=str, required=True, help='hypothesis kaldi arc file')
    parser.add_argument('result_file', type=str)
    args = parser.parse_args()
    logging.info(args)

    stats = ErrorStats()

    logging.info('Generating tokenizer ...')
    if args.tokenizer == 'whitespace':

        def word_tokenizer(text):
            return text.strip().split()

        tokenizer = word_tokenizer
    elif args.tokenizer == 'char':

        def char_tokenizer(text):
            return [c for c in text.strip().replace(' ', '')]

        tokenizer = char_tokenizer
    else:
        tokenizer = None
    assert tokenizer

    logging.info('Loading REF & HYP ...')
    ref_utts = LoadKaldiArc(args.ref)
    hyp_utts = LoadKaldiArc(args.hyp)

    # keep only hyp utterances that have a matched, non-empty reference
    uids = []
    for uid in sorted(hyp_utts.keys()):
        if uid in ref_utts:
            if ref_utts[uid].text.strip():  # non-empty reference
                uids.append(uid)
            else:
                logging.warning(F'Found {uid} with empty reference, skipping...')
        else:
            logging.warning(F'Found {uid} without reference, skipping...')
            stats.num_hyp_without_ref += 1

    stats.num_hyp_utts = len(hyp_utts)
    stats.num_ref_utts = len(ref_utts)
    stats.num_eval_utts = len(uids)
    logging.info(f' #hyp:{stats.num_hyp_utts}, #ref:{stats.num_ref_utts}, #utts_to_evaluate:{stats.num_eval_utts}')
    print(f' #hyp:{stats.num_hyp_utts}, #ref:{stats.num_ref_utts}, #utts_to_evaluate:{stats.num_eval_utts}')

    tokens = []
    for uid in uids:
        ref_tokens = tokenizer(ref_utts[uid].text)
        hyp_tokens = tokenizer(hyp_utts[uid].text)
        for t in ref_tokens + hyp_tokens:
            tokens.append(t)
            if '-' in t:
                tokens.extend(BreakHyphen(t))
    vocab_from_utts = list(set(tokens))
    logging.info(f' HYP&REF vocab size: {len(vocab_from_utts)}')
    print(f' HYP&REF vocab size: {len(vocab_from_utts)}')

    assert args.glm
    glm = LoadGLM(args.glm)

    tokens = []
    for rule in glm.values():
        for phrase in rule:
            for t in tokenizer(phrase):
                tokens.append(t)
                if '-' in t:
                    tokens.extend(BreakHyphen(t))
    vocab_from_glm = list(set(tokens))
    logging.info(f' GLM vocab size: {len(vocab_from_glm)}')
    print(f' GLM vocab size: {len(vocab_from_glm)}')

    vocab = list(set(vocab_from_utts + vocab_from_glm))
    logging.info(f'Global vocab size: {len(vocab)}')
    print(f'Global vocab size: {len(vocab)}')

    symtab = BuildSymbolTable(
        # Normal evaluation vocab + auxiliary form for alternative paths + GLM tags
        vocab
        + [x + '#' for x in vocab]
        + [x for x in glm.keys()]
    )
    glm_tagger = BuildGLMTagger(glm, symtab)
    edit_transducer = EditTransducer(symbol_table=symtab, vocab=vocab)

    logging.info('Evaluating error rate ...')
    print('Evaluating error rate ...')
    fo = open(args.result_file, 'w', encoding='utf8')
    ndone = 0
    for uid in uids:
        ref = ref_utts[uid].text
        raw_hyp = hyp_utts[uid].text

        ref_fst = pynini.accep(' '.join(tokenizer(ref)), token_type=symtab)
        # print(ref_fst.string(token_type = symtab))

        raw_hyp_fst = pynini.accep(' '.join(tokenizer(raw_hyp)), token_type=symtab)
        # print(raw_hyp_fst.string(token_type = symtab))

        # Say, we have:
        #   RULE_001: "I'M" <-> "I AM"
        #   REF: HEY I AM HERE
        #   HYP: HEY I'M HERE
        #
        # We want to expand HYP with GLM rules (marked with auxiliary #):
        #   HYP#: HEY {I'M | I# AM#} HERE
        # REF is honored to keep its original form.
        #
        # This could be considered as a flexible on-the-fly TN towards HYP.

        # 1. GLM rule tagging:
        #   HEY I'M HERE
        #   ->
        #   HEY <RULE_001> I'M <RULE_001> HERE
        lattice = (raw_hyp_fst @ glm_tagger).optimize()
        tagged_ir = pynini.shortestpath(lattice, nshortest=1, unique=True).string(token_type=symtab)
        # print(tagged_ir)

        # 2. GLM rule expansion:
        #   HEY <RULE_001> I'M <RULE_001> HERE
        #   ->
        #   sausage-like fst: HEY {I'M | I# AM#} HERE
        tokens = tagged_ir.split()
        sausage = pynini.accep('', token_type=symtab)
        i = 0
        while i < len(tokens):  # invariant: tokens[0, i) has been built into the fst
            forms = []
            if tokens[i].startswith('<RULE_') and tokens[i].endswith('>'):  # rule segment
                rule_name = tokens[i]
                rule = glm[rule_name]
                # pre-condition: i -> ltag
                raw_form = ''
                for j in range(i + 1, len(tokens)):
                    if tokens[j] == rule_name:
                        raw_form = ' '.join(tokens[i + 1 : j])
                        break
                assert raw_form
                # post-condition: i -> ltag, j -> rtag

                forms.append(raw_form)
                for phrase in rule:
                    if phrase != raw_form:
                        forms.append(' '.join([x + '#' for x in phrase.split()]))
                i = j + 1
            else:  # normal token segment
                token = tokens[i]
                forms.append(token)
                if "-" in token:  # a token with a hyphen yields extra forms
                    forms.append(' '.join([x + '#' for x in token.split('-')]))  # 'T-SHIRT' -> 'T# SHIRT#'
                    forms.append(token.replace('-', '') + '#')  # 'T-SHIRT' -> 'TSHIRT#'
                i += 1

            sausage_segment = pynini.union(*[pynini.accep(x, token_type=symtab) for x in forms]).optimize()
            sausage += sausage_segment
        hyp_fst = sausage.optimize()

        # Utterance-Level error rate evaluation
        alignment = edit_transducer.compute_alignment(ref_fst, hyp_fst)

        distance = 0.0
        C, S, I, D = 0, 0, 0, 0  # Cor, Sub, Ins, Del
        edit_ali, ref_ali, hyp_ali = [], [], []
        for state in alignment.states():
            for arc in alignment.arcs(state):
                i, o = arc.ilabel, arc.olabel
                if i != 0 and o != 0 and SymbolEQ(symtab, i, o):
                    e = 'C'
                    r, h = symtab.find(i), symtab.find(o)

                    C += 1
                    distance += 0.0
                elif i != 0 and o != 0 and not SymbolEQ(symtab, i, o):
                    e = 'S'
                    r, h = symtab.find(i), symtab.find(o)

                    S += 1
                    distance += 1.0
                elif i == 0 and o != 0:
                    e = 'I'
                    r, h = '*', symtab.find(o)

                    I += 1
                    distance += 1.0
                elif i != 0 and o == 0:
                    e = 'D'
                    r, h = symtab.find(i), '*'

                    D += 1
                    distance += 1.0
                else:
                    raise RuntimeError('unexpected epsilon-to-epsilon arc in alignment')

                edit_ali.append(e)
                ref_ali.append(r)
                hyp_ali.append(h)
        # assert(distance == edit_transducer.compute_distance(ref_fst, sausage))

        utt_ter, utt_mter = ComputeTokenErrorRate(C, S, I, D)
        # print(F'{{"uid":{uid}, "score":{-distance}, "TER":{utt_ter:.2f}, "mTER":{utt_mter:.2f}, "cor":{C}, "sub":{S}, "ins":{I}, "del":{D}}}', file=fo)
        # PrintPrettyAlignment(raw_hyp, edit_ali, ref_ali, hyp_ali, fo)

        if utt_ter > 0:
            stats.num_utts_with_error += 1

        stats.C += C
        stats.S += S
        stats.I += I
        stats.D += D

        ndone += 1
        if ndone % args.logk == 0:
            logging.info(f'{ndone} utts evaluated.')
    logging.info(f'{ndone} utts evaluated in total.')

    # Corpus-Level evaluation
    stats.token_error_rate, stats.modified_token_error_rate = ComputeTokenErrorRate(stats.C, stats.S, stats.I, stats.D)
    stats.sentence_error_rate = ComputeSentenceErrorRate(stats.num_utts_with_error, stats.num_eval_utts)

    print(stats.to_json(), file=fo)
    # print(stats.to_kaldi())
    # print(stats.to_summary(), file=fo)

    fo.close()