371 lines
12 KiB
Python
371 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
# coding=utf8
|
|
|
|
# Copyright 2021 Jiayu DU
|
|
|
|
import sys
|
|
import argparse
|
|
import json
|
|
import logging
|
|
# Log INFO and above to stderr as "[LEVEL] message".
logging.basicConfig(
    format='[%(levelname)s] %(message)s',
    level=logging.INFO,
    stream=sys.stderr,
)

# Module-wide debug switch: when truthy, EditDistance dumps its DP search grid to stderr.
DEBUG = None
|
|
|
|
def GetEditType(ref_token, hyp_token):
    """Classify a single alignment arc by the tokens it consumes.

    Args:
        ref_token: reference token, or None when the arc consumes no REF token.
        hyp_token: hypothesis token, or None when the arc consumes no HYP token.

    Returns:
        'I' for an insertion (HYP token only), 'D' for a deletion (REF token
        only), 'C' when both tokens are equal, and 'S' (substitution) otherwise.
    """
    if ref_token is None and hyp_token is not None:
        return 'I'
    if ref_token is not None and hyp_token is None:
        return 'D'
    return 'C' if ref_token == hyp_token else 'S'
|
|
|
|
class AlignmentArc:
    """One step of an alignment path through the edit-distance search grid.

    Attributes:
        src, dst: (r, h) grid coordinates the arc leaves from / arrives at.
        ref, hyp: tokens consumed on each axis; None on an axis that is not advanced.
        edit_type: 'C', 'S', 'I' or 'D', derived from (ref, hyp) by GetEditType.
    """

    def __init__(self, src, dst, ref, hyp):
        self.src, self.dst = src, dst
        self.ref, self.hyp = ref, hyp
        self.edit_type = GetEditType(ref, hyp)
|
|
|
|
def similarity_score_function(ref_token, hyp_token):
    """Score a diagonal (correct / substitution) move in the alignment grid.

    Returns 0.0 when the tokens are equal and -1.0 otherwise. The value is
    always a float, consistent with the insertion and deletion score functions.
    """
    return 0.0 if ref_token == hyp_token else -1.0
|
|
|
|
def insertion_score_function(token):
    """Score for inserting ``token`` (a HYP token with no REF partner).

    The penalty is flat; ``token`` itself is not inspected.
    """
    penalty = -1.0
    return penalty
|
|
|
|
def deletion_score_function(token):
    """Score for deleting ``token`` (a REF token with no HYP partner).

    The penalty is flat; ``token`` itself is not inspected.
    """
    penalty = -1.0
    return penalty
|
|
|
|
def EditDistance(
        ref,
        hyp,
        similarity_score_function = similarity_score_function,
        insertion_score_function = insertion_score_function,
        deletion_score_function = deletion_score_function):
    """Align ``hyp`` against ``ref`` with a score-maximising dynamic program.

    The search grid has ``len(ref) + 1`` rows (REF axis) and ``len(hyp) + 1``
    columns (HYP axis). Every cell keeps the best score reaching it and a
    backpointer to its predecessor. Scores are MAXIMISED, so the score
    functions should return higher values for better moves.

    Args:
        ref: non-empty sequence of reference tokens.
        hyp: sequence of hypothesis tokens (may be empty).
        similarity_score_function: ``f(ref_token, hyp_token)`` scoring a
            diagonal (correct / substitution) move.
        insertion_score_function: ``f(hyp_token)`` scoring an insertion.
        deletion_score_function: ``f(ref_token)`` scoring a deletion.

    Returns:
        ``(best_path, best_score)``. ``best_path`` is the list of
        ``AlignmentArc`` from the grid origin to ``(len(ref), len(hyp))``, in
        forward order. ``best_score`` is the score of that final cell.

    Raises:
        AssertionError: if ``ref`` is empty.
    """
    assert(len(ref) != 0)

    class DPState:
        def __init__(self):
            self.score = -float('inf')
            # backpointer to the predecessor cell; stays None only at the origin
            self.prev_r = None
            self.prev_h = None

    def print_search_grid(S, R, H, fstream):
        print(file=fstream)
        for r in range(R):
            for h in range(H):
                print(F'[{r},{h}]:{S[r][h].score:4.3f}:({S[r][h].prev_r},{S[r][h].prev_h}) ', end='', file=fstream)
            print(file=fstream)

    R = len(ref) + 1
    H = len(hyp) + 1

    # Construct DP search space, a (R x H) grid
    S = [[DPState() for _ in range(H)] for _ in range(R)]

    # Grid origin S(0, 0): nothing consumed on either axis; no backpointer.
    S[0][0].score = 0.0

    # REF axis (h == 0): reachable only through deletions.
    for r in range(1, R):
        S[r][0].score = S[r-1][0].score + deletion_score_function(ref[r-1])
        S[r][0].prev_r = r-1
        S[r][0].prev_h = 0

    # HYP axis (r == 0): reachable only through insertions.
    for h in range(1, H):
        S[0][h].score = S[0][h-1].score + insertion_score_function(hyp[h-1])
        S[0][h].prev_r = 0
        S[0][h].prev_h = h-1

    for r in range(1, R):
        for h in range(1, H):
            cell = S[r][h]
            # Moves are relaxed in a fixed order with ``>=``, so on a tie the
            # LATER move wins: insertion > deletion > substitution/correct.
            # Changing the order or the comparison changes which of several
            # equally-scored alignments is reported.
            moves = (
                (r-1, h-1, similarity_score_function(ref[r-1], hyp[h-1])),  # sub / cor
                (r-1, h,   deletion_score_function(ref[r-1])),              # deletion
                (r,   h-1, insertion_score_function(hyp[h-1])),             # insertion
            )
            for prev_r, prev_h, move_score in moves:
                new_score = S[prev_r][prev_h].score + move_score
                if new_score >= cell.score:
                    cell.score = new_score
                    cell.prev_r = prev_r
                    cell.prev_h = prev_h

    # The full alignment ends at the bottom-right cell.
    best_score = S[R-1][H-1].score
    best_state = (R-1, H-1)

    if DEBUG:
        print_search_grid(S, R, H, sys.stderr)

    # Backtrace the best alignment path (a list of arcs) by following
    # backpointers from best_state back to the origin.
    # src/dst = (r, h), where r/h refers to search grid state-id along Ref/Hyp axis
    best_path = []
    r, h = best_state
    while S[r][h].prev_r is not None or S[r][h].prev_h is not None:
        prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h
        src, dst = (prev_r, prev_h), (r, h)
        if r == prev_r + 1 and h == prev_h + 1:  # Substitution or correct
            arc = AlignmentArc(src, dst, ref[prev_r], hyp[prev_h])
        elif r == prev_r + 1 and h == prev_h:  # Deletion
            arc = AlignmentArc(src, dst, ref[prev_r], None)
        elif r == prev_r and h == prev_h + 1:  # Insertion
            arc = AlignmentArc(src, dst, None, hyp[prev_h])
        else:
            raise RuntimeError
        best_path.append(arc)
        r, h = prev_r, prev_h

    best_path.reverse()
    return (best_path, best_score)
|
|
|
|
def PrettyPrintAlignment(alignment, stream = sys.stderr):
|
|
def get_token_str(token):
|
|
if token == None:
|
|
return "*"
|
|
return token
|
|
|
|
def is_double_width_char(ch):
|
|
if (ch >= '\u4e00') and (ch <= '\u9fa5'): # codepoint ranges for Chinese chars
|
|
return True
|
|
# TODO: support other double-width-char language such as Japanese, Korean
|
|
else:
|
|
return False
|
|
|
|
def display_width(token_str):
|
|
m = 0
|
|
for c in token_str:
|
|
if is_double_width_char(c):
|
|
m += 2
|
|
else:
|
|
m += 1
|
|
return m
|
|
|
|
R = ' REF : '
|
|
H = ' HYP : '
|
|
E = ' EDIT : '
|
|
for arc in alignment:
|
|
r = get_token_str(arc.ref)
|
|
h = get_token_str(arc.hyp)
|
|
e = arc.edit_type if arc.edit_type != 'C' else ''
|
|
|
|
nr, nh, ne = display_width(r), display_width(h), display_width(e)
|
|
n = max(nr, nh, ne) + 1
|
|
|
|
R += r + ' ' * (n-nr)
|
|
H += h + ' ' * (n-nh)
|
|
E += e + ' ' * (n-ne)
|
|
|
|
print(R, file=stream)
|
|
print(H, file=stream)
|
|
print(E, file=stream)
|
|
|
|
def CountEdits(alignment):
    """Tally the arcs of ``alignment`` by edit type.

    Returns:
        A tuple ``(correct, substitutions, insertions, deletions)``.

    Raises:
        RuntimeError: if an arc has an edit type other than C/S/I/D.
    """
    tally = {'C': 0, 'S': 0, 'I': 0, 'D': 0}
    for arc in alignment:
        kind = arc.edit_type
        if kind not in tally:
            raise RuntimeError
        tally[kind] += 1
    return (tally['C'], tally['S'], tally['I'], tally['D'])
|
|
|
|
def ComputeTokenErrorRate(c, s, i, d):
    """Return the token error rate in percent: 100 * (S + D + I) / (S + D + C).

    The denominator is the number of reference tokens. ZeroDivisionError is
    raised when it is 0.
    """
    num_errors = s + d + i
    num_ref_tokens = s + d + c
    return 100.0 * num_errors / num_ref_tokens
|
|
|
|
def ComputeSentenceErrorRate(num_err_utts, num_utts):
    """Return the percentage of utterances that contain at least one error.

    Raises:
        AssertionError: if ``num_utts`` is 0 (when assertions are enabled).
    """
    assert num_utts != 0
    percent = 100.0 * num_err_utts / num_utts
    return percent
|
|
|
|
|
|
class EvaluationResult:
    """Corpus-level evaluation counters and error rates, plus report renderers."""

    def __init__(self):
        # NOTE: attribute order is preserved in to_json() output.
        self.num_ref_utts = 0
        self.num_hyp_utts = 0
        self.num_eval_utts = 0 # seen in both ref & hyp
        self.num_hyp_without_ref = 0

        # token-level edit counts: correct, substitution, insertion, deletion
        self.C = 0
        self.S = 0
        self.I = 0
        self.D = 0
        self.token_error_rate = 0.0  # percent

        self.num_utts_with_error = 0
        self.sentence_error_rate = 0.0  # percent

    def to_json(self):
        """Serialize every attribute as a flat JSON object."""
        return json.dumps(self.__dict__)

    def to_kaldi(self):
        """Render the two-line ``%WER`` / ``%SER`` summary in Kaldi's style."""
        info = (
            F'%WER {self.token_error_rate:.2f} [ {self.S + self.D + self.I} / {self.C + self.S + self.D}, {self.I} ins, {self.D} del, {self.S} sub ]\n'
            F'%SER {self.sentence_error_rate:.2f} [ {self.num_utts_with_error} / {self.num_eval_utts} ]\n'
        )
        return info

    def to_sclite(self):
        """Placeholder: sclite-style output is not implemented yet."""
        return "TODO"

    def to_espnet(self):
        """Placeholder: ESPnet-style output is not implemented yet."""
        return "TODO"

    def to_summary(self):
        """Render a human-readable multi-line summary of all statistics."""
        # The token_stats labels share one width so the right-aligned (:>7)
        # counts form a single column.
        summary = (
            '==================== Overall Statistics ====================\n'
            F'num_ref_utts: {self.num_ref_utts}\n'
            F'num_hyp_utts: {self.num_hyp_utts}\n'
            F'num_hyp_without_ref: {self.num_hyp_without_ref}\n'
            F'num_eval_utts: {self.num_eval_utts}\n'
            F'sentence_error_rate: {self.sentence_error_rate:.2f}%\n'
            F'token_error_rate: {self.token_error_rate:.2f}%\n'
            F'token_stats:\n'
            F' - tokens:{self.C + self.S + self.D:>7}\n'
            F' - edits: {self.S + self.I + self.D:>7}\n'
            F' - cor:   {self.C:>7}\n'
            F' - sub:   {self.S:>7}\n'
            F' - ins:   {self.I:>7}\n'
            F' - del:   {self.D:>7}\n'
            '============================================================\n'
        )
        return summary
|
|
|
|
|
|
class Utterance:
    """A single utterance: its id and its raw (untokenized) text."""

    def __init__(self, uid, text):
        self.uid, self.text = uid, text
|
|
|
|
|
|
def LoadUtterances(filepath, format):
    """Load utterances from ``filepath`` into a dict keyed by utterance id.

    Supported formats:
        'text': one utterance per line, ``utt_id word1 word2 ...``. Each line
            is stripped. Blank lines are skipped, and a line holding only an id
            yields an utterance with empty text.

    Returns:
        dict mapping uid -> Utterance.

    Raises:
        RuntimeError: on an unsupported ``format`` or a duplicated utterance id.
    """
    if format != 'text':
        raise RuntimeError(F'Unsupported text format {format}')

    utts = {}
    with open(filepath, 'r', encoding='utf8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # maxsplit=1 keeps the text's internal whitespace intact; a
            # non-empty stripped line always yields 1 or 2 columns.
            cols = line.split(maxsplit=1)
            uid = cols[0]
            text = cols[1] if len(cols) == 2 else ''
            if uid in utts:
                raise RuntimeError(F'Found duplicated utterence id {uid}')
            utts[uid] = Utterance(uid, text)
    return utts
|
|
|
|
|
|
def tokenize_text(text, tokenizer):
    """Split ``text`` into tokens.

    Args:
        text: raw utterance text.
        tokenizer: 'whitespace' (words, for WER) or 'char' (every
            non-whitespace character is a token, for CER).

    Raises:
        RuntimeError: on an unsupported ``tokenizer``.
    """
    if tokenizer == 'whitespace':
        return text.split()
    if tokenizer == 'char':
        # Remove all whitespace first so spaces never become tokens.
        return list(''.join(text.split()))
    raise RuntimeError(F'ERROR: Unsupported tokenizer {tokenizer}')
|
|
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # optional
    parser.add_argument('--tokenizer', choices=['whitespace', 'char'], default='whitespace', help='whitespace for WER, char for CER')
    parser.add_argument('--ref-format', choices=['text'], default='text', help='reference format, first col is utt_id, the rest is text')
    parser.add_argument('--hyp-format', choices=['text'], default='text', help='hypothesis format, first col is utt_id, the rest is text')
    # required
    parser.add_argument('--ref', type=str, required=True, help='input reference file')
    parser.add_argument('--hyp', type=str, required=True, help='input hypothesis file')

    parser.add_argument('result_file', type=str)
    args = parser.parse_args()
    logging.info(args)

    ref_utts = LoadUtterances(args.ref, args.ref_format)
    hyp_utts = LoadUtterances(args.hyp, args.hyp_format)

    r = EvaluationResult()

    # Only hyp utterances that have a matching, non-empty reference are
    # evaluated. NOTE: reference ids with no hypothesis are never scored;
    # they only count toward num_ref_utts.
    eval_utts = []
    for uid in sorted(hyp_utts):
        if uid not in ref_utts:
            logging.warning(F'Found {uid} without reference, skipping...')
            r.num_hyp_without_ref += 1
        elif not ref_utts[uid].text.strip():
            logging.warning(F'Found {uid} with empty reference, skipping...')
        else:
            eval_utts.append(uid)

    r.num_hyp_utts = len(hyp_utts)
    r.num_ref_utts = len(ref_utts)
    r.num_eval_utts = len(eval_utts)

    # Corpus-level rates divide by the number of evaluated utterances/tokens,
    # so fail early with a clear message instead of a bare assertion later.
    if not eval_utts:
        logging.error('No hypothesis utterance has a matching non-empty reference, nothing to evaluate')
        sys.exit(1)

    with open(args.result_file, 'w', encoding='utf8') as fo:
        for uid in eval_utts:
            ref = ref_utts[uid]
            hyp = hyp_utts[uid]

            alignment, score = EditDistance(
                tokenize_text(ref.text, args.tokenizer),
                tokenize_text(hyp.text, args.tokenizer)
            )

            c, s, i, d = CountEdits(alignment)
            utt_ter = ComputeTokenErrorRate(c, s, i, d)

            # utt-level evaluation result: one JSON object per line (uid is
            # JSON-encoded so arbitrary ids stay valid JSON), then the alignment
            print(F'{{"uid":{json.dumps(uid, ensure_ascii=False)}, "score":{score}, "ter":{utt_ter:.2f}, "cor":{c}, "sub":{s}, "ins":{i}, "del":{d}}}', file=fo)
            PrettyPrintAlignment(alignment, fo)

            r.C += c
            r.S += s
            r.I += i
            r.D += d

            if utt_ter > 0:
                r.num_utts_with_error += 1

        # corpus level evaluation result
        r.sentence_error_rate = ComputeSentenceErrorRate(r.num_utts_with_error, r.num_eval_utts)
        r.token_error_rate = ComputeTokenErrorRate(r.C, r.S, r.I, r.D)

        print(r.to_summary(), file=fo)

    print(r.to_json())
    print(r.to_kaldi())
|