Files
enginex-mr_series-asr/utils/speechio/error_rate_en.py
2025-08-20 14:29:42 +08:00

552 lines
19 KiB
Python

#!/usr/bin/env python3
# coding=utf8
# Copyright 2022 Zhenxiang MA, Jiayu DU (SpeechColab)
import argparse
import csv
import json
import logging
import os
import sys
from typing import Iterable
# Only ERROR-and-above records are emitted, so the logging.info()/logging.warning()
# calls in this script are silent unless this level is lowered.
logging.basicConfig(stream=sys.stderr, level=logging.ERROR, format='[%(levelname)s] %(message)s')
import pynini
from pynini.lib import pynutil
# reference: https://github.com/kylebgorman/pynini/blob/master/pynini/lib/edit_transducer.py
# to import original lib:
# from pynini.lib.edit_transducer import EditTransducer
class EditTransducer:
    """Weighted edit-distance transducer over a fixed token vocabulary.

    Adapted from ``pynini.lib.edit_transducer.EditTransducer``. The edit machine
    is split into a left factor (``self._e_i``) and a right factor
    (``self._e_o``). ``create_lattice`` composes ``iexpr @ _e_i @ _e_o @ oexpr``,
    and the shortest path through that lattice is a minimum-cost alignment of
    the two token sequences.

    Unlike the upstream version, the right factor is extended so that a token
    ``x`` on the input side can match its auxiliary form ``x#`` on the output
    side at zero cost.
    """

    DELETE = "<delete>"
    INSERT = "<insert>"
    SUBSTITUTE = "<substitute>"

    def __init__(
        self,
        symbol_table,
        vocab: Iterable[str],
        insert_cost: float = 1.0,
        delete_cost: float = 1.0,
        substitute_cost: float = 1.0,
        bound: int = 0,
    ):
        """Build the left and right factors.

        Args:
            symbol_table: pynini symbol table containing every token of ``vocab``
                and, for each token, its auxiliary form ``token + '#'``.
            vocab: tokens that may appear in either expression being aligned.
            insert_cost: cost of one insertion.
            delete_cost: cost of one deletion.
            substitute_cost: cost of one substitution.
            bound: if non-zero, the left factor allows at most ``bound`` edits;
                0 means unbounded.
        """
        # vocab is traversed twice below (sigma and right_factor_ext); materialize it
        # so that a one-shot iterator does not leave the second traversal empty.
        vocab = list(vocab)
        # Left factor; note that we divide the edit costs by two because they also
        # will be incurred when traversing the right factor.
        sigma = pynini.union(
            *[pynini.accep(token, token_type=symbol_table) for token in vocab],
        ).optimize()
        insert = pynutil.insert(f"[{self.INSERT}]", weight=insert_cost / 2)
        delete = pynini.cross(sigma, pynini.accep(f"[{self.DELETE}]", weight=delete_cost / 2))
        substitute = pynini.cross(sigma, pynini.accep(f"[{self.SUBSTITUTE}]", weight=substitute_cost / 2))
        edit = pynini.union(insert, delete, substitute).optimize()
        if bound:
            sigma_star = pynini.closure(sigma)
            self._e_i = sigma_star.copy()
            for _ in range(bound):
                self._e_i.concat(edit.ques).concat(sigma_star)
        else:
            self._e_i = edit.union(sigma).closure()
        self._e_i.optimize()
        right_factor_std = EditTransducer._right_factor(self._e_i)
        # right_factor_ext allows 0-cost matching between token's raw form & auxiliary form
        # e.g.: 'I' -> 'I#', 'AM' -> 'AM#'
        right_factor_ext = (
            pynini.union(
                *[
                    pynini.cross(
                        pynini.accep(x, token_type=symbol_table),
                        pynini.accep(x + '#', token_type=symbol_table),
                    )
                    for x in vocab
                ]
            )
            .optimize()
            .closure()
        )
        self._e_o = pynini.union(right_factor_std, right_factor_ext).closure().optimize()

    @staticmethod
    def _right_factor(ifst: pynini.Fst) -> pynini.Fst:
        """Return the right factor derived from the left factor ``ifst``.

        ``ifst`` is inverted, and the <insert>/<delete> marker labels are swapped.
        After composition, the left factor's eps:<insert> then meets <insert>:token,
        which is an insertion. Its token:<delete> then meets <delete>:eps, which is
        a deletion.
        """
        ofst = pynini.invert(ifst)
        # Bracketed strings such as "[<insert>]" were registered as generated symbols.
        syms = pynini.generated_symbols()
        insert_label = syms.find(EditTransducer.INSERT)
        delete_label = syms.find(EditTransducer.DELETE)
        pairs = [(insert_label, delete_label), (delete_label, insert_label)]
        right_factor = ofst.relabel_pairs(ipairs=pairs)
        return right_factor

    def create_lattice(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> pynini.Fst:
        """Compose ``iexpr`` and ``oexpr`` through the edit machine.

        Raises:
            RuntimeError: if the composition is empty (no alignment exists).
        """
        lattice = (iexpr @ self._e_i) @ (self._e_o @ oexpr)
        EditTransducer.check_wellformed_lattice(lattice)
        return lattice

    @staticmethod
    def check_wellformed_lattice(lattice: pynini.Fst) -> None:
        """Raise RuntimeError if ``lattice`` has no start state, i.e. is empty."""
        if lattice.start() == pynini.NO_STATE_ID:
            raise RuntimeError("Edit distance composition lattice is empty.")

    def compute_distance(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> float:
        """Return the minimum total edit cost between ``iexpr`` and ``oexpr``."""
        lattice = self.create_lattice(iexpr, oexpr)
        # The shortest cost from all final states to the start state is
        # equivalent to the cost of the shortest path.
        start = lattice.start()
        return float(pynini.shortestdistance(lattice, reverse=True)[start])

    def compute_alignment(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> pynini.Fst:
        """Return a single best alignment path.

        Input labels come from ``iexpr`` and output labels from ``oexpr``. An
        epsilon label (0) on one side marks an insertion or a deletion.
        """
        lattice = self.create_lattice(iexpr, oexpr)
        alignment = pynini.shortestpath(lattice, nshortest=1, unique=True)
        return alignment.optimize()
class ErrorStats:
    """Accumulated utterance counts, token edit counts and error rates for one run."""

    def __init__(self):
        # Utterance bookkeeping.
        self.num_ref_utts = 0
        self.num_hyp_utts = 0
        self.num_eval_utts = 0  # in both ref & hyp
        self.num_hyp_without_ref = 0
        # Token edit counts: correct, substituted, inserted, deleted.
        self.C = 0
        self.S = 0
        self.I = 0
        self.D = 0
        # Rates (rendered with a '%' sign by to_summary).
        self.token_error_rate = 0.0
        self.modified_token_error_rate = 0.0
        self.num_utts_with_error = 0
        self.sentence_error_rate = 0.0

    def to_json(self):
        """Serialize every attribute, in definition order, as a one-line JSON object."""
        return json.dumps(vars(self))

    def to_kaldi(self):
        """Render the stats as Kaldi-style '%WER' and '%SER' lines."""
        num_errors = self.S + self.D + self.I
        num_ref_tokens = self.C + self.S + self.D
        wer_line = (
            F'%WER {self.token_error_rate:.2f} '
            F'[ {num_errors} / {num_ref_tokens}, {self.I} ins, {self.D} del, {self.S} sub ]\n'
        )
        ser_line = F'%SER {self.sentence_error_rate:.2f} [ {self.num_utts_with_error} / {self.num_eval_utts} ]\n'
        return wer_line + ser_line

    def to_summary(self):
        """Render a multi-line, human-readable report of all counts and rates."""
        num_tokens = self.C + self.S + self.D
        num_edits = self.S + self.I + self.D
        rows = [
            '==================== Overall Statistics ====================',
            F'num_ref_utts: {self.num_ref_utts}',
            F'num_hyp_utts: {self.num_hyp_utts}',
            F'num_hyp_without_ref: {self.num_hyp_without_ref}',
            F'num_eval_utts: {self.num_eval_utts}',
            F'sentence_error_rate: {self.sentence_error_rate:.2f}%',
            F'token_error_rate: {self.token_error_rate:.2f}%',
            F'modified_token_error_rate: {self.modified_token_error_rate:.2f}%',
            'token_stats:',
            F' - tokens:{num_tokens:>7}',
            F' - edits: {num_edits:>7}',
            F' - cor: {self.C:>7}',
            F' - sub: {self.S:>7}',
            F' - ins: {self.I:>7}',
            F' - del: {self.D:>7}',
            '============================================================',
        ]
        return ''.join(row + '\n' for row in rows)
class Utterance:
    """One entry of a Kaldi-style text file: an utterance id and its text."""

    def __init__(self, uid, text):
        self.uid, self.text = uid, text
def LoadKaldiArc(filepath):
    """Load a Kaldi-style text file into a dict of Utterance objects.

    Each non-blank line has the form ``<uid> [<text>]``. The first
    whitespace-delimited field is the utterance id. The remainder, possibly
    empty, is the text. Surrounding whitespace on the line is stripped first.

    Args:
        filepath: path to a UTF-8 encoded text file.

    Returns:
        dict mapping uid -> Utterance, in file order.

    Raises:
        RuntimeError: if an utterance id appears more than once.
    """
    utts = {}
    with open(filepath, 'r', encoding='utf8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # On a stripped, non-empty line this yields [uid] or [uid, text].
            cols = line.split(maxsplit=1)
            uid = cols[0]
            text = cols[1] if len(cols) == 2 else ''
            if uid in utts:
                raise RuntimeError(F'Found duplicated utterance id {uid}')
            utts[uid] = Utterance(uid, text)
    return utts
def BreakHyphen(token: str):
    """Return the extra vocabulary entries derived from a hyphenated token.

    These are the hyphen-separated pieces in order, followed by the token with
    every hyphen removed. For example, 'T-SHIRT' -> ['T', 'SHIRT', 'TSHIRT'].
    """
    assert '-' in token
    pieces = token.split('-')
    joined = ''.join(pieces)
    return pieces + [joined]
def LoadGLM(rel_path):
    '''
    Load GLM rules from a CSV file: one rule per row, one phrase per column.

    rel_path is resolved relative to the directory of this script. An absolute
    path is used as-is. Whitespace around each phrase is stripped.

    glm.csv:
        I'VE,I HAVE
        GOING TO,GONNA
        ...
        T-SHIRT,T SHIRT,TSHIRT
    glm:
    {
        '<RULE_000000>': ["I'VE", 'I HAVE'],
        '<RULE_000001>': ['GOING TO', 'GONNA'],
        ...
        '<RULE_099999>': ['T-SHIRT', 'T SHIRT', 'TSHIRT'],
    }
    '''
    logging.info(f'Loading GLM from {rel_path} ...')
    # os.path.join keeps rel_path unchanged when it is already absolute.
    abs_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), rel_path)
    # newline='' lets the csv module handle line endings itself, as its docs require.
    with open(abs_path, encoding="utf-8", newline='') as f:
        rows = list(csv.reader(f, delimiter=','))
    glm = {f'<RULE_{k:06d}>': [phrase.strip() for phrase in rule] for k, rule in enumerate(rows)}
    logging.info(f' #rule: {len(glm)}')
    return glm
def SymbolEQ(symbol_table, i1, i2):
    """Return True if labels i1 and i2 name the same token, modulo one auxiliary '#'.

    A hypothesis may carry the auxiliary form ``token + '#'`` of a vocabulary
    token, so 'AM' and 'AM#' compare equal. Only a single trailing '#' on one
    side is ignored.
    """
    s1, s2 = symbol_table.find(i1), symbol_table.find(i2)
    # Do not use str.strip('#') here: it also drops leading and repeated '#'.
    # That would make e.g. '#1' equal '1', and 'A' equal 'A##' (the auxiliary form of 'A#').
    return s1 == s2 or s1 + '#' == s2 or s2 + '#' == s1
def PrintSymbolTable(symbol_table: pynini.SymbolTable):
    """Print every 'id symbol' pair of symbol_table to stdout, then a blank line.

    It also asserts that id -> symbol -> id lookups round-trip.
    """
    print('SYMBOL_TABLE:')
    key = 0
    while key < symbol_table.num_symbols():
        symbol = symbol_table.find(key)
        # find() works in both directions: id -> symbol and symbol -> id.
        assert symbol_table.find(symbol) == key
        print(key, symbol)
        key += 1
    print()
def BuildSymbolTable(vocab) -> pynini.SymbolTable:
    """Create a pynini symbol table holding '<epsilon>' followed by every vocab entry.

    '<epsilon>' is added first to a fresh table. Other code here assumes it has
    id 0.
    """
    logging.info('Building symbol table ...')
    table = pynini.SymbolTable()
    for symbol in ['<epsilon>', *vocab]:
        table.add_symbol(symbol)
    logging.info(f' #symbols: {table.num_symbols()}')
    # PrintSymbolTable(table)
    # table.write_text('symbol_table.txt')
    return table
def BuildGLMTagger(glm, symbol_table) -> pynini.Fst:
    """Build an FST that wraps each GLM phrase between two copies of its rule tag.

    For example, with "I'M" in rule <RULE_001>:
    "HEY I'M HERE" -> "HEY <RULE_001> I'M <RULE_001> HERE".
    The rewrite is a context-independent cdrewrite over all non-epsilon symbols.
    """
    logging.info('Building GLM tagger ...')

    def accep(text):
        return pynini.accep(text, token_type=symbol_table)

    rule_taggers = [
        pynutil.insert(accep(rule_tag)) + accep(phrase) + pynutil.insert(accep(rule_tag))
        for rule_tag, rule in glm.items()
        for phrase in rule
    ]
    # Sigma for cdrewrite: every symbol except key 0 (epsilon).
    alphabet = pynini.union(*[accep(sym) for key, sym in symbol_table if key != 0]).optimize()
    tagger = pynini.cdrewrite(pynini.union(*rule_taggers).optimize(), '', '', alphabet.closure())
    return tagger.optimize()  # could be slow with large vocabulary
def TokenWidth(token: str):
    """Return the display width of token.

    CJK ideographs in U+4E00..U+9FA5 count as 2 columns; every other character
    counts as 1.
    """
    return sum(2 if '\u4e00' <= ch <= '\u9fa5' else 1 for ch in token)
def PrintPrettyAlignment(raw_hyp, edit_ali, ref_ali, hyp_ali, stream=sys.stderr):
    """Print one utterance's alignment as column-aligned HYP / HYP# / REF / EDIT rows.

    edit_ali, ref_ali and hyp_ali must have equal length. Correct ('C') edits
    are shown as blank.
    """
    assert len(edit_ali) == len(ref_ali) and len(ref_ali) == len(hyp_ali)
    hyp_row, ref_row, edit_row = ' HYP# : ', ' REF : ', ' EDIT : '
    for edit, ref_tok, hyp_tok in zip(edit_ali, ref_ali, hyp_ali):
        if edit == 'C':
            edit = ''  # don't bother printing correct edit-tag
        w_ref, w_hyp, w_edit = TokenWidth(ref_tok), TokenWidth(hyp_tok), TokenWidth(edit)
        column = max(w_ref, w_hyp, w_edit) + 1
        hyp_row += hyp_tok + ' ' * (column - w_hyp)
        ref_row += ref_tok + ' ' * (column - w_ref)
        edit_row += edit + ' ' * (column - w_edit)
    print(F' HYP : {raw_hyp}', file=stream)
    for row in (hyp_row, ref_row, edit_row):
        print(row, file=stream)
def ComputeTokenErrorRate(c, s, i, d):
    """Return (TER, modified TER), both in percent.

    The inputs are the correct, substitution, insertion and deletion counts.
    TER divides the edit count (s + d + i) by the reference length (c + s + d).
    The modified TER divides it by max(reference length, hypothesis length),
    where the hypothesis length is c + s + i.
    """
    assert (s + d + c) != 0
    edits = s + d + i
    ref_len, hyp_len = c + s + d, c + s + i
    ter = 100.0 * edits / ref_len
    mter = 100.0 * edits / max(ref_len, hyp_len)
    return ter, mter
def ComputeSentenceErrorRate(num_err_utts, num_utts):
    """Return num_err_utts as a percentage of num_utts (which must be non-zero)."""
    assert num_utts != 0
    percent = 100.0 * num_err_utts / num_utts
    return percent
def _with_hyphen_forms(tokens):
    """Return tokens in order, each hyphenated token followed by its BreakHyphen() forms.

    e.g. ['A', 'T-SHIRT'] -> ['A', 'T-SHIRT', 'T', 'SHIRT', 'TSHIRT']
    """
    out = []
    for t in tokens:
        out.append(t)
        if '-' in t:
            out.extend(BreakHyphen(t))
    return out


def _build_hyp_sausage(tagged_ir, glm, symtab):
    """Expand a GLM-tagged hypothesis string into a sausage-like FST.

    Each span enclosed by a pair of identical rule tags becomes a union of its raw
    form and the rule's other phrases in auxiliary form (every token suffixed
    with '#'):
        HEY <RULE_001> I'M <RULE_001> HERE  ->  HEY {I'M | I# AM#} HERE
    An untagged hyphenated token also gets its split and joined auxiliary forms:
        T-SHIRT -> {T-SHIRT | T# SHIRT# | TSHIRT#}

    Raises:
        RuntimeError: if a rule tag has no closing tag or encloses nothing.
    """
    tokens = tagged_ir.split()
    sausage = pynini.accep('', token_type=symtab)
    i = 0
    while i < len(tokens):  # invariant: tokens[0, i) has been built into the FST
        forms = []
        if tokens[i].startswith('<RULE_') and tokens[i].endswith('>'):  # rule segment
            rule_name = tokens[i]
            try:
                j = tokens.index(rule_name, i + 1)  # matching closing tag
            except ValueError:
                raise RuntimeError(f'Unclosed GLM tag {rule_name} in: {tagged_ir}') from None
            raw_form = ' '.join(tokens[i + 1 : j])
            if not raw_form:
                raise RuntimeError(f'Empty GLM span for {rule_name} in: {tagged_ir}')
            forms.append(raw_form)
            for phrase in glm[rule_name]:
                if phrase != raw_form:
                    forms.append(' '.join(x + '#' for x in phrase.split()))
            i = j + 1
        else:  # normal token segment
            token = tokens[i]
            forms.append(token)
            if '-' in token:  # token with hyphen yields extra forms
                forms.append(' '.join(x + '#' for x in token.split('-')))  # 'T-SHIRT' -> 'T# SHIRT#'
                forms.append(token.replace('-', '') + '#')  # 'T-SHIRT' -> 'TSHIRT#'
            i += 1
        sausage += pynini.union(*[pynini.accep(x, token_type=symtab) for x in forms]).optimize()
    return sausage.optimize()


def _count_edits(alignment, symtab):
    """Tally correct/substitution/insertion/deletion arcs of an alignment FST.

    Input labels are REF tokens and output labels are HYP tokens (as produced by
    EditTransducer.compute_alignment(ref, hyp)). Label 0 is epsilon.

    Returns:
        (C, S, I, D, edit_ali, ref_ali, hyp_ali). The *_ali lists hold one entry
        per arc: the edit tag, the REF token and the HYP token. '*' stands for
        the empty side of an insertion or deletion.

    Raises:
        RuntimeError: on an epsilon:epsilon arc.
    """
    counts = {'C': 0, 'S': 0, 'I': 0, 'D': 0}
    edit_ali, ref_ali, hyp_ali = [], [], []
    for state in alignment.states():
        for arc in alignment.arcs(state):
            ilabel, olabel = arc.ilabel, arc.olabel
            if ilabel != 0 and olabel != 0:
                e = 'C' if SymbolEQ(symtab, ilabel, olabel) else 'S'
                r, h = symtab.find(ilabel), symtab.find(olabel)
            elif ilabel == 0 and olabel != 0:
                e, r, h = 'I', '*', symtab.find(olabel)
            elif ilabel != 0 and olabel == 0:
                e, r, h = 'D', symtab.find(ilabel), '*'
            else:
                raise RuntimeError('Unexpected epsilon:epsilon arc in alignment')
            counts[e] += 1
            edit_ali.append(e)
            ref_ali.append(r)
            hyp_ali.append(h)
    return counts['C'], counts['S'], counts['I'], counts['D'], edit_ali, ref_ali, hyp_ali


def main():
    """Command-line entry point: score HYP against REF and write JSON stats to result_file."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--logk', type=int, default=500, help='logging interval')
    parser.add_argument(
        '--tokenizer', choices=['whitespace', 'char'], default='whitespace', help='whitespace for WER, char for CER'
    )
    parser.add_argument('--glm', type=str, default='glm_en.csv', help='glm')
    parser.add_argument('--ref', type=str, required=True, help='reference kaldi arc file')
    parser.add_argument('--hyp', type=str, required=True, help='hypothesis kaldi arc file')
    parser.add_argument('result_file', type=str)
    args = parser.parse_args()
    if not args.glm:
        parser.error('--glm must name a GLM csv file')
    logging.info(args)

    stats = ErrorStats()

    logging.info('Generating tokenizer ...')
    if args.tokenizer == 'char':
        def tokenizer(text):
            return list(text.strip().replace(' ', ''))
    else:  # 'whitespace'; argparse `choices` rules out any other value

        def tokenizer(text):
            return text.strip().split()

    logging.info('Loading REF & HYP ...')
    ref_utts = LoadKaldiArc(args.ref)
    hyp_utts = LoadKaldiArc(args.hyp)

    # Evaluate only hyp utterances that have a matching, non-empty reference.
    uids = []
    for uid in sorted(hyp_utts):
        if uid not in ref_utts:
            logging.warning(F'Found {uid} without reference, skipping...')
            stats.num_hyp_without_ref += 1
        elif not ref_utts[uid].text.strip():
            logging.warning(F'Found {uid} with empty reference, skipping...')
        else:
            uids.append(uid)
    stats.num_hyp_utts = len(hyp_utts)
    stats.num_ref_utts = len(ref_utts)
    stats.num_eval_utts = len(uids)
    msg = f' #hyp:{stats.num_hyp_utts}, #ref:{stats.num_ref_utts}, #utts_to_evaluate:{stats.num_eval_utts}'
    logging.info(msg)
    print(msg)
    if not uids:
        # Corpus-level rates are undefined without any evaluated utterance.
        sys.exit('ERROR: no hypothesis utterance has a matching non-empty reference')

    utt_tokens = []
    for uid in uids:
        utt_tokens.extend(tokenizer(ref_utts[uid].text))
        utt_tokens.extend(tokenizer(hyp_utts[uid].text))
    # Sort instead of using raw set order, which depends on the per-process string
    # hash seed. This keeps symbol ids (and any tie-breaking they may influence)
    # identical across runs.
    vocab_from_utts = sorted(set(_with_hyphen_forms(utt_tokens)))
    logging.info(f' HYP&REF vocab size: {len(vocab_from_utts)}')
    print(f' HYP&REF vocab size: {len(vocab_from_utts)}')

    glm = LoadGLM(args.glm)
    glm_tokens = [t for rule in glm.values() for phrase in rule for t in tokenizer(phrase)]
    vocab_from_glm = sorted(set(_with_hyphen_forms(glm_tokens)))
    logging.info(f' GLM vocab size: {len(vocab_from_glm)}')
    print(f' GLM vocab size: {len(vocab_from_glm)}')

    vocab = sorted(set(vocab_from_utts) | set(vocab_from_glm))
    logging.info(f'Global vocab size: {len(vocab)}')
    print(f'Global vocab size: {len(vocab)}')

    symtab = BuildSymbolTable(
        # Normal evaluation vocab + auxiliary form for alternative paths + GLM tags
        vocab
        + [x + '#' for x in vocab]
        + list(glm.keys())
    )
    glm_tagger = BuildGLMTagger(glm, symtab)
    edit_transducer = EditTransducer(symbol_table=symtab, vocab=vocab)

    logging.info('Evaluating error rate ...')
    print('Evaluating error rate ...')
    ndone = 0
    with open(args.result_file, 'w', encoding='utf8') as fo:
        for uid in uids:
            ref = ref_utts[uid].text
            raw_hyp = hyp_utts[uid].text
            ref_fst = pynini.accep(' '.join(tokenizer(ref)), token_type=symtab)
            raw_hyp_fst = pynini.accep(' '.join(tokenizer(raw_hyp)), token_type=symtab)

            # Say, we have:
            #   RULE_001: "I'M" <-> "I AM"
            #   REF: HEY I AM HERE
            #   HYP: HEY I'M HERE
            # We want to expand HYP with GLM rules (marked with auxiliary #):
            #   HYP#: HEY {I'M | I# AM#} HERE
            # REF is honored to keep its original form.
            # This could be considered as a flexible on-the-fly TN towards HYP.

            # 1. GLM rule tagging: HEY I'M HERE -> HEY <RULE_001> I'M <RULE_001> HERE
            lattice = (raw_hyp_fst @ glm_tagger).optimize()
            tagged_ir = pynini.shortestpath(lattice, nshortest=1, unique=True).string(token_type=symtab)
            # 2. GLM rule expansion: -> sausage-like FST HEY {I'M | I# AM#} HERE
            hyp_fst = _build_hyp_sausage(tagged_ir, glm, symtab)

            # Utterance-level error rate evaluation
            alignment = edit_transducer.compute_alignment(ref_fst, hyp_fst)
            C, S, I, D, edit_ali, ref_ali, hyp_ali = _count_edits(alignment, symtab)
            utt_ter, utt_mter = ComputeTokenErrorRate(C, S, I, D)
            # Per-utterance debugging output (disabled):
            # print(F'{{"uid":{uid}, "TER":{utt_ter:.2f}, "mTER":{utt_mter:.2f}, "cor":{C}, "sub":{S}, "ins":{I}, "del":{D}}}', file=fo)
            # PrintPrettyAlignment(raw_hyp, edit_ali, ref_ali, hyp_ali, fo)
            if utt_ter > 0:
                stats.num_utts_with_error += 1
            stats.C += C
            stats.S += S
            stats.I += I
            stats.D += D

            ndone += 1
            if args.logk > 0 and ndone % args.logk == 0:
                logging.info(f'{ndone} utts evaluated.')
        logging.info(f'{ndone} utts evaluated in total.')

        # Corpus-Level evaluation
        stats.token_error_rate, stats.modified_token_error_rate = ComputeTokenErrorRate(
            stats.C, stats.S, stats.I, stats.D
        )
        stats.sentence_error_rate = ComputeSentenceErrorRate(stats.num_utts_with_error, stats.num_eval_utts)
        print(stats.to_json(), file=fo)
        # print(stats.to_kaldi())
        # print(stats.to_summary(), file=fo)


if __name__ == '__main__':
    main()