#!/usr/bin/env python3
# coding=utf8
# Copyright 2022 Zhenxiang MA, Jiayu DU (SpeechColab)

import argparse
import csv
import json
import logging
import os
import sys
from typing import Iterable

logging.basicConfig(stream=sys.stderr, level=logging.INFO, format='[%(levelname)s] %(message)s')

import pynini
from pynini.lib import pynutil


# reference: https://github.com/kylebgorman/pynini/blob/master/pynini/lib/edit_transducer.py
# to import the original lib:
#   from pynini.lib.edit_transducer import EditTransducer
class EditTransducer:
    DELETE = "<delete>"
    INSERT = "<insert>"
    SUBSTITUTE = "<substitute>"

    def __init__(
        self,
        symbol_table,
        vocab: Iterable[str],
        insert_cost: float = 1.0,
        delete_cost: float = 1.0,
        substitute_cost: float = 1.0,
        bound: int = 0,
    ):
        # Left factor; note that we divide the edit costs by two because they also
        # will be incurred when traversing the right factor.
        sigma = pynini.union(
            *[pynini.accep(token, token_type=symbol_table) for token in vocab],
        ).optimize()
        insert = pynutil.insert(f"[{self.INSERT}]", weight=insert_cost / 2)
        delete = pynini.cross(sigma, pynini.accep(f"[{self.DELETE}]", weight=delete_cost / 2))
        substitute = pynini.cross(sigma, pynini.accep(f"[{self.SUBSTITUTE}]", weight=substitute_cost / 2))
        edit = pynini.union(insert, delete, substitute).optimize()
        if bound:
            sigma_star = pynini.closure(sigma)
            self._e_i = sigma_star.copy()
            for _ in range(bound):
                self._e_i.concat(edit.ques).concat(sigma_star)
        else:
            self._e_i = edit.union(sigma).closure()
        self._e_i.optimize()

        right_factor_std = EditTransducer._right_factor(self._e_i)
        # right_factor_ext allows 0-cost matching between a token's raw form and
        # its auxiliary form, e.g.: 'I' -> 'I#', 'AM' -> 'AM#'
        right_factor_ext = (
            pynini.union(
                *[
                    pynini.cross(
                        pynini.accep(x, token_type=symbol_table),
                        pynini.accep(x + '#', token_type=symbol_table),
                    )
                    for x in vocab
                ]
            )
            .optimize()
            .closure()
        )
        self._e_o = pynini.union(right_factor_std, right_factor_ext).closure().optimize()

    @staticmethod
    def _right_factor(ifst: pynini.Fst) -> pynini.Fst:
        ofst = pynini.invert(ifst)
        syms = pynini.generated_symbols()
        insert_label = syms.find(EditTransducer.INSERT)
        delete_label = syms.find(EditTransducer.DELETE)
        pairs = [(insert_label, delete_label), (delete_label, insert_label)]
        right_factor = ofst.relabel_pairs(ipairs=pairs)
        return right_factor

    def create_lattice(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> pynini.Fst:
        lattice = (iexpr @ self._e_i) @ (self._e_o @ oexpr)
        EditTransducer.check_wellformed_lattice(lattice)
        return lattice

    @staticmethod
    def check_wellformed_lattice(lattice: pynini.Fst) -> None:
        if lattice.start() == pynini.NO_STATE_ID:
            raise RuntimeError("Edit distance composition lattice is empty.")

    def compute_distance(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> float:
        lattice = self.create_lattice(iexpr, oexpr)
        # The shortest cost from all final states to the start state is
        # equivalent to the cost of the shortest path.
        start = lattice.start()
        return float(pynini.shortestdistance(lattice, reverse=True)[start])

    def compute_alignment(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> pynini.FstLike:
        lattice = self.create_lattice(iexpr, oexpr)
        alignment = pynini.shortestpath(lattice, nshortest=1, unique=True)
        return alignment.optimize()
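
# A minimal usage sketch of the class above (illustrative only, never called by
# this script; the toy vocabulary, strings, and expected cost are assumptions):
def _edit_transducer_demo():
    vocab = ['HEY', 'I', 'AM', 'HERE']
    symtab = pynini.SymbolTable()
    symtab.add_symbol('<epsilon>')  # index 0 is reserved for epsilon
    for w in vocab + [w + '#' for w in vocab]:  # raw + auxiliary forms
        symtab.add_symbol(w)
    et = EditTransducer(symbol_table=symtab, vocab=vocab)
    ref = pynini.accep('HEY I AM HERE', token_type=symtab)
    hyp = pynini.accep('HEY I HERE', token_type=symtab)
    # a single deletion at default costs: delete_cost/2 is paid in the left
    # factor and again in the right factor, summing to 1.0
    print(et.compute_distance(ref, hyp))  # expected: 1.0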
class ErrorStats:
    def __init__(self):
        self.num_ref_utts = 0
        self.num_hyp_utts = 0
        self.num_eval_utts = 0  # in both ref & hyp
        self.num_hyp_without_ref = 0

        self.C = 0
        self.S = 0
        self.I = 0
        self.D = 0
        self.token_error_rate = 0.0
        self.modified_token_error_rate = 0.0

        self.num_utts_with_error = 0
        self.sentence_error_rate = 0.0

    def to_json(self):
        # return json.dumps(self.__dict__, indent=4)
        return json.dumps(self.__dict__)

    def to_kaldi(self):
        info = (
            F'%WER {self.token_error_rate:.2f} [ {self.S + self.D + self.I} / {self.C + self.S + self.D}, {self.I} ins, {self.D} del, {self.S} sub ]\n'
            F'%SER {self.sentence_error_rate:.2f} [ {self.num_utts_with_error} / {self.num_eval_utts} ]\n'
        )
        return info

    def to_summary(self):
        summary = (
            '==================== Overall Statistics ====================\n'
            F'num_ref_utts: {self.num_ref_utts}\n'
            F'num_hyp_utts: {self.num_hyp_utts}\n'
            F'num_hyp_without_ref: {self.num_hyp_without_ref}\n'
            F'num_eval_utts: {self.num_eval_utts}\n'
            F'sentence_error_rate: {self.sentence_error_rate:.2f}%\n'
            F'token_error_rate: {self.token_error_rate:.2f}%\n'
            F'modified_token_error_rate: {self.modified_token_error_rate:.2f}%\n'
            F'token_stats:\n'
            F'  - tokens:{self.C + self.S + self.D:>7}\n'
            F'  - edits: {self.S + self.I + self.D:>7}\n'
            F'  - cor:   {self.C:>7}\n'
            F'  - sub:   {self.S:>7}\n'
            F'  - ins:   {self.I:>7}\n'
            F'  - del:   {self.D:>7}\n'
            '============================================================\n'
        )
        return summary


class Utterance:
    def __init__(self, uid, text):
        self.uid = uid
        self.text = text


def LoadKaldiArc(filepath):
    utts = {}
    with open(filepath, 'r', encoding='utf8') as f:
        for line in f:
            line = line.strip()
            if line:
                cols = line.split(maxsplit=1)
                assert len(cols) == 2 or len(cols) == 1
                uid = cols[0]
                text = cols[1] if len(cols) == 2 else ''
                if uid in utts:
                    raise RuntimeError(F'Found duplicated utterance id {uid}')
                utts[uid] = Utterance(uid, text)
    return utts


def BreakHyphen(token: str):
    # 'T-SHIRT' should also introduce new words into the vocabulary, e.g.:
    #   1. 'T' & 'SHIRT'
    #   2. 'TSHIRT'
    assert '-' in token
    v = token.split('-')
    v.append(token.replace('-', ''))
    return v


def LoadGLM(rel_path):
    '''
    glm.csv:
        I'VE,I HAVE
        GOING TO,GONNA
        ...
        T-SHIRT,T SHIRT,TSHIRT

    glm:
        {
            '<RULE_000>': ["I'VE", 'I HAVE'],
            '<RULE_001>': ['GOING TO', 'GONNA'],
            ...
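# Worked example for the helper above (values written out by hand):
#   BreakHyphen('T-SHIRT') -> ['T', 'SHIRT', 'TSHIRT']
# i.e. both split parts and the joined form are added to the vocabulary, so a
# hypothesis can match a reference that writes the word either way.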
            '<RULE_NNN>': ['T-SHIRT', 'T SHIRT', 'TSHIRT'],
        }
    '''
    logging.info(f'Loading GLM from {rel_path} ...')
    abs_path = os.path.dirname(os.path.abspath(__file__)) + '/' + rel_path
    reader = list(csv.reader(open(abs_path, encoding="utf-8"), delimiter=','))
    glm = {}
    for k, rule in enumerate(reader):
        rule_name = f'<RULE_{k:03d}>'  # e.g. '<RULE_001>'
        glm[rule_name] = [phrase.strip() for phrase in rule]
    logging.info(f'  #rule: {len(glm)}')
    return glm


def SymbolEQ(symbol_table, i1, i2):
    # two symbols match if they agree after stripping the auxiliary '#'
    return symbol_table.find(i1).strip('#') == symbol_table.find(i2).strip('#')


def PrintSymbolTable(symbol_table: pynini.SymbolTable):
    print('SYMBOL_TABLE:')
    for k in range(symbol_table.num_symbols()):
        sym = symbol_table.find(k)
        # a symbol table's find() can be used for bi-directional lookup (id <-> sym)
        assert symbol_table.find(sym) == k
        print(k, sym)
    print()


def BuildSymbolTable(vocab) -> pynini.SymbolTable:
    logging.info('Building symbol table ...')
    symbol_table = pynini.SymbolTable()
    symbol_table.add_symbol('<epsilon>')
    for w in vocab:
        symbol_table.add_symbol(w)
    logging.info(f'  #symbols: {symbol_table.num_symbols()}')
    # PrintSymbolTable(symbol_table)
    # symbol_table.write_text('symbol_table.txt')
    return symbol_table


def BuildGLMTagger(glm, symbol_table) -> pynini.Fst:
    logging.info('Building GLM tagger ...')
    rule_taggers = []
    for rule_tag, rule in glm.items():
        for phrase in rule:
            rule_taggers.append(
                pynutil.insert(pynini.accep(rule_tag, token_type=symbol_table))
                + pynini.accep(phrase, token_type=symbol_table)
                + pynutil.insert(pynini.accep(rule_tag, token_type=symbol_table))
            )
    alphabet = pynini.union(
        *[pynini.accep(sym, token_type=symbol_table) for k, sym in symbol_table if k != 0]  # non-epsilon
    ).optimize()
    tagger = pynini.cdrewrite(
        pynini.union(*rule_taggers).optimize(), '', '', alphabet.closure()
    ).optimize()  # could be slow with a large vocabulary
    return tagger


def TokenWidth(token: str):
    def CharWidth(c):
        # CJK characters occupy two terminal columns
        return 2 if (c >= '\u4e00') and (c <= '\u9fa5') else 1
    return sum([CharWidth(c) for c in token])


def PrintPrettyAlignment(raw_hyp, edit_ali, ref_ali, hyp_ali, stream=sys.stderr):
    assert len(edit_ali) == len(ref_ali) and len(ref_ali) == len(hyp_ali)
    H = '  HYP# : '
    R = '  REF  : '
    E = '  EDIT : '
    for i, e in enumerate(edit_ali):
        h, r = hyp_ali[i], ref_ali[i]
        e = '' if e == 'C' else e  # don't bother printing the correct edit-tag
        nr, nh, ne = TokenWidth(r), TokenWidth(h), TokenWidth(e)
        n = max(nr, nh, ne) + 1
        H += h + ' ' * (n - nh)
        R += r + ' ' * (n - nr)
        E += e + ' ' * (n - ne)
    print(F'  HYP  : {raw_hyp}', file=stream)
    print(H, file=stream)
    print(R, file=stream)
    print(E, file=stream)


def ComputeTokenErrorRate(c, s, i, d):
    assert (s + d + c) != 0
    num_edits = s + d + i
    ref_len = c + s + d
    hyp_len = c + s + i
    return 100.0 * num_edits / ref_len, 100.0 * num_edits / max(ref_len, hyp_len)


def ComputeSentenceErrorRate(num_err_utts, num_utts):
    assert num_utts != 0
    return 100.0 * num_err_utts / num_utts
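
# Worked example for ComputeTokenErrorRate (numbers chosen by hand):
# with C=3, S=1, I=2, D=1,
#   num_edits = 1 + 1 + 2 = 4,  ref_len = 3 + 1 + 1 = 5,  hyp_len = 3 + 1 + 2 = 6,
# so TER = 100 * 4 / 5 = 80.00, and the modified TER, which normalizes by the
# longer of the two sides, is 100 * 4 / max(5, 6) ~= 66.67.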
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--logk', type=int, default=500, help='logging interval')
    parser.add_argument(
        '--tokenizer', choices=['whitespace', 'char'], default='whitespace', help='whitespace for WER, char for CER'
    )
    parser.add_argument('--glm', type=str, default='glm_en.csv', help='glm')
    parser.add_argument('--ref', type=str, required=True, help='reference kaldi arc file')
    parser.add_argument('--hyp', type=str, required=True, help='hypothesis kaldi arc file')
    parser.add_argument('result_file', type=str)
    args = parser.parse_args()
    logging.info(args)

    stats = ErrorStats()

    logging.info('Generating tokenizer ...')
    if args.tokenizer == 'whitespace':
        def word_tokenizer(text):
            return text.strip().split()
        tokenizer = word_tokenizer
    elif args.tokenizer == 'char':
        def char_tokenizer(text):
            return [c for c in text.strip().replace(' ', '')]
        tokenizer = char_tokenizer
    else:
        tokenizer = None
    assert tokenizer

    logging.info('Loading REF & HYP ...')
    ref_utts = LoadKaldiArc(args.ref)
    hyp_utts = LoadKaldiArc(args.hyp)

    # keep only utterances in hyp that have a matched, non-empty reference
    uids = []
    for uid in sorted(hyp_utts.keys()):
        if uid in ref_utts.keys():
            if ref_utts[uid].text.strip():  # non-empty reference
                uids.append(uid)
            else:
                logging.warning(F'Found {uid} with empty reference, skipping...')
        else:
            logging.warning(F'Found {uid} without reference, skipping...')
            stats.num_hyp_without_ref += 1
    stats.num_hyp_utts = len(hyp_utts)
    stats.num_ref_utts = len(ref_utts)
    stats.num_eval_utts = len(uids)
    logging.info(f'  #hyp:{stats.num_hyp_utts}, #ref:{stats.num_ref_utts}, #utts_to_evaluate:{stats.num_eval_utts}')

    tokens = []
    for uid in uids:
        ref_tokens = tokenizer(ref_utts[uid].text)
        hyp_tokens = tokenizer(hyp_utts[uid].text)
        for t in ref_tokens + hyp_tokens:
            tokens.append(t)
            if '-' in t:
                tokens.extend(BreakHyphen(t))
    vocab_from_utts = list(set(tokens))
    logging.info(f'  HYP&REF vocab size: {len(vocab_from_utts)}')

    assert args.glm
    glm = LoadGLM(args.glm)
    tokens = []
    for rule in glm.values():
        for phrase in rule:
            for t in tokenizer(phrase):
                tokens.append(t)
                if '-' in t:
                    tokens.extend(BreakHyphen(t))
    vocab_from_glm = list(set(tokens))
    logging.info(f'  GLM vocab size: {len(vocab_from_glm)}')

    vocab = list(set(vocab_from_utts + vocab_from_glm))
    logging.info(f'Global vocab size: {len(vocab)}')

    symtab = BuildSymbolTable(
        # normal evaluation vocab + auxiliary forms for alternative paths + GLM tags
        vocab + [x + '#' for x in vocab] + [x for x in glm.keys()]
    )
    glm_tagger = BuildGLMTagger(glm, symtab)
    edit_transducer = EditTransducer(symbol_table=symtab, vocab=vocab)
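
    # Illustrative sketch of what glm_tagger does, mirroring the tagging step in
    # the loop below (the rule id is hypothetical, assuming glm_en.csv maps
    # "I'M" <-> "I AM" under <RULE_001>):
    #   s = pynini.accep("HEY I'M HERE", token_type=symtab)
    #   pynini.shortestpath(s @ glm_tagger, nshortest=1, unique=True).string(token_type=symtab)
    #   -> "HEY <RULE_001> I'M <RULE_001> HERE"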

    logging.info('Evaluating error rate ...')
    fo = open(args.result_file, 'w+', encoding='utf8')
    ndone = 0
    for uid in uids:
        ref = ref_utts[uid].text
        raw_hyp = hyp_utts[uid].text

        ref_fst = pynini.accep(' '.join(tokenizer(ref)), token_type=symtab)
        # print(ref_fst.string(token_type=symtab))
        raw_hyp_fst = pynini.accep(' '.join(tokenizer(raw_hyp)), token_type=symtab)
        # print(raw_hyp_fst.string(token_type=symtab))

        # Say, we have:
        #   RULE_001: "I'M" <-> "I AM"
        #   REF: HEY I AM HERE
        #   HYP: HEY I'M HERE
        #
        # We want to expand HYP with GLM rules (marked with auxiliary #):
        #   HYP#: HEY {I'M | I# AM#} HERE
        # REF is honored to keep its original form.
        #
        # This can be considered a flexible on-the-fly TN applied to HYP.

        # 1. GLM rule tagging:
        #      HEY I'M HERE
        #      ->
        #      HEY <RULE_001> I'M <RULE_001> HERE
        lattice = (raw_hyp_fst @ glm_tagger).optimize()
        tagged_ir = pynini.shortestpath(lattice, nshortest=1, unique=True).string(token_type=symtab)
        # print(tagged_ir)

        # 2. GLM rule expansion:
        #      HEY <RULE_001> I'M <RULE_001> HERE
        #      ->
        #      sausage-like fst: HEY {I'M | I# AM#} HERE
        tokens = tagged_ir.split()
        sausage = pynini.accep('', token_type=symtab)
        i = 0
        while i < len(tokens):  # invariant: tokens[0, i) has been built into the fst
            forms = []
            if tokens[i].startswith('<RULE'):  # rule segment
                rule_name = tokens[i]
                rule = glm[rule_name]
                # pre-condition: i -> ltag
                raw_form = ''
                for j in range(i + 1, len(tokens)):
                    if tokens[j] == rule_name:
                        raw_form = ' '.join(tokens[i + 1 : j])
                        break
                assert raw_form
                # post-condition: i -> ltag, j -> rtag
                forms.append(raw_form)
                for phrase in rule:
                    if phrase != raw_form:
                        forms.append(' '.join([x + '#' for x in phrase.split()]))
                i = j + 1
            else:  # normal token segment
                token = tokens[i]
                forms.append(token)
                if "-" in token:  # token with hyphen yields extra forms
                    forms.append(' '.join([x + '#' for x in token.split('-')]))  # 'T-SHIRT' -> 'T# SHIRT#'
                    forms.append(token.replace('-', '') + '#')  # 'T-SHIRT' -> 'TSHIRT#'
                i += 1
            sausage_segment = pynini.union(*[pynini.accep(x, token_type=symtab) for x in forms]).optimize()
            sausage += sausage_segment
        hyp_fst = sausage.optimize()

        # Utterance-level error rate evaluation
        alignment = edit_transducer.compute_alignment(ref_fst, hyp_fst)

        distance = 0.0
        C, S, I, D = 0, 0, 0, 0  # Cor, Sub, Ins, Del
        edit_ali, ref_ali, hyp_ali = [], [], []
        for state in alignment.states():
            for arc in alignment.arcs(state):
                i, o = arc.ilabel, arc.olabel
                if i != 0 and o != 0 and SymbolEQ(symtab, i, o):
                    e = 'C'
                    r, h = symtab.find(i), symtab.find(o)
                    C += 1
                    distance += 0.0
                elif i != 0 and o != 0 and not SymbolEQ(symtab, i, o):
                    e = 'S'
                    r, h = symtab.find(i), symtab.find(o)
                    S += 1
                    distance += 1.0
                elif i == 0 and o != 0:
                    e = 'I'
                    r, h = '*', symtab.find(o)
                    I += 1
                    distance += 1.0
                elif i != 0 and o == 0:
                    e = 'D'
                    r, h = symtab.find(i), '*'
                    D += 1
                    distance += 1.0
                else:
                    raise RuntimeError
                edit_ali.append(e)
                ref_ali.append(r)
                hyp_ali.append(h)
        # assert(distance == edit_transducer.compute_distance(ref_fst, sausage))

        utt_ter, utt_mter = ComputeTokenErrorRate(C, S, I, D)
        # print(F'{{"uid":{uid}, "score":{-distance}, "TER":{utt_ter:.2f}, "mTER":{utt_mter:.2f}, "cor":{C}, "sub":{S}, "ins":{I}, "del":{D}}}', file=fo)
        # PrintPrettyAlignment(raw_hyp, edit_ali, ref_ali, hyp_ali, fo)
        if utt_ter > 0:
            stats.num_utts_with_error += 1

        stats.C += C
        stats.S += S
        stats.I += I
        stats.D += D

        ndone += 1
        if ndone % args.logk == 0:
            logging.info(f'{ndone} utts evaluated.')
    logging.info(f'{ndone} utts evaluated in total.')

    # Corpus-level evaluation
    stats.token_error_rate, stats.modified_token_error_rate = ComputeTokenErrorRate(stats.C, stats.S, stats.I, stats.D)
    stats.sentence_error_rate = ComputeSentenceErrorRate(stats.num_utts_with_error, stats.num_eval_utts)

    print(stats.to_json(), file=fo)
    # print(stats.to_kaldi())
    # print(stats.to_summary(), file=fo)
    fo.close()
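
# Example invocation (the script and data file names here are hypothetical;
# --ref/--hyp expect the Kaldi text-archive format read by LoadKaldiArc, i.e.
# one utterance per line: "<uid> <transcript ...>"):
#
#   python3 error_rate.py --tokenizer whitespace --glm glm_en.csv \
#       --ref ref.txt --hyp hyp.txt result.json
#
# The overall ErrorStats JSON is written to result.json; per-utterance result
# lines and pretty alignments are available via the commented-out prints in
# the evaluation loop.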