update
This commit is contained in:
93
tests/test_cer.py
Normal file
93
tests/test_cer.py
Normal file
@@ -0,0 +1,93 @@
|
||||
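"""
f(a, b) computes the edit distance a -> b using the method from the earlier ASR leaderboard.
g(a, b) computes the edit distance a -> b using the plain (original) edit-distance computation.
test() is the differential test that checks the two against each other.
"""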
|
||||
|
||||
import random
|
||||
import string
|
||||
from copy import deepcopy
|
||||
from typing import List, Tuple
|
||||
|
||||
import Levenshtein
|
||||
|
||||
|
||||
def mapping(gt: str, dt: str):
    """Split two strings into per-character token lists.

    Args:
        gt: First string; each character becomes one token.
        dt: Second string; each character becomes one token.

    Returns:
        A ``(gt_tokens, dt_tokens)`` tuple of two new lists.
    """
    return list(gt), list(dt)
|
||||
|
||||
|
||||
def token_mapping(
    tokens_gt: List[str], tokens_dt: List[str]
) -> Tuple[List[str], List[str]]:
    """Pad two token sequences with ``None`` according to their edit operations.

    The edit operations that turn ``tokens_gt`` into ``tokens_dt`` are taken
    from ``Levenshtein.editops``. For each "insert" operation a ``None`` is
    inserted into the copy of ``tokens_gt`` at the operation's source
    position; for each "delete" operation a ``None`` is inserted into the copy
    of ``tokens_dt`` at the operation's destination position. "replace"
    operations leave both copies unchanged.

    Args:
        tokens_gt: Source token sequence (not modified).
        tokens_dt: Destination token sequence (not modified).

    Returns:
        ``(padded_gt, padded_dt)``: new lists that may contain ``None``
        placeholders, intended to line up position by position.
    """
    # A shallow copy is enough: only the lists are changed, never the tokens.
    padded_gt = list(tokens_gt)
    padded_dt = list(tokens_dt)
    operations = Levenshtein.editops(padded_gt, padded_dt)
    # Apply back-to-front so an insertion never shifts the positions used by
    # the operations that are still pending.
    for tag, gt_pos, dt_pos in reversed(operations):
        if tag == "insert":
            padded_gt.insert(gt_pos, None)
        elif tag == "delete":
            padded_dt.insert(dt_pos, None)
    return padded_gt, padded_dt
|
||||
|
||||
|
||||
def cer(tokens_gt_mapping: List[str], tokens_dt_mapping: List[str]):
    """Count substitutions, deletions and insertions in two aligned token lists.

    Both arguments are expected to be position-aligned sequences in which a
    missing counterpart is marked with ``None`` (the form produced by
    ``token_mapping``).

    Args:
        tokens_gt_mapping: Aligned source tokens; ``None`` marks an insertion.
        tokens_dt_mapping: Aligned destination tokens; ``None`` marks a
            deletion.

    Returns:
        A ``(replace, delete, insert)`` tuple of integer counts.
    """
    insert = sum(1 for token in tokens_gt_mapping if token is None)
    delete = sum(1 for token in tokens_dt_mapping if token is None)
    equal = sum(
        1
        for token_gt, token_dt in zip(tokens_gt_mapping, tokens_dt_mapping)
        if token_gt == token_dt
    )
    # Each aligned position is exactly one of equal / replace / insert /
    # delete, so deletions have to be subtracted as well; otherwise every
    # deletion would be reported a second time inside ``replace``.
    replace = len(tokens_gt_mapping) - insert - equal - delete
    return replace, delete, insert
|
||||
|
||||
|
||||
def f(a, b):
    """Return the counts produced by the alignment-based pipeline.

    ``a`` and ``b`` are split with ``mapping``, padded with ``token_mapping``
    and the padded sequences are handed to ``cer``; its result is returned.
    """
    tokens_a, tokens_b = mapping(a, b)
    padded_a, padded_b = token_mapping(tokens_a, tokens_b)
    return cer(padded_a, padded_b)
|
||||
|
||||
|
||||
def raw(tokens_gt, tokens_dt):
    """Count edit operations directly from ``Levenshtein.editops``.

    Args:
        tokens_gt: Source token sequence.
        tokens_dt: Destination token sequence.

    Returns:
        A ``(replace, delete, insert)`` tuple with the number of operations
        carrying each tag.
    """
    # The inputs are only read, so no defensive copy is required.
    tags = [op[0] for op in Levenshtein.editops(tokens_gt, tokens_dt)]
    return tags.count("replace"), tags.count("delete"), tags.count("insert")
|
||||
|
||||
|
||||
def g(a, b):
    """Return the counts reported by ``raw`` for the tokens of ``a`` and ``b``.

    The two inputs are first split into tokens with ``mapping``.
    """
    tokens_a, tokens_b = mapping(a, b)
    return raw(tokens_a, tokens_b)
|
||||
|
||||
|
||||
def check(a, b):
    """Compare ``f(a, b)`` with ``g(a, b)``.

    Prints both results when they differ.

    Returns:
        ``True`` if the two results are equal, ``False`` otherwise.
    """
    aligned_counts = f(a, b)
    direct_counts = g(a, b)
    agree = aligned_counts == direct_counts
    if not agree:
        print(aligned_counts, direct_counts)
    return agree
|
||||
|
||||
|
||||
def random_string(length):
    """Build a string of ``length`` randomly chosen lowercase ASCII letters."""
    alphabet = string.ascii_lowercase
    picked = []
    for _ in range(length):
        # One ``random.choice`` call per character, in order.
        picked.append(random.choice(alphabet))
    return "".join(picked)
|
||||
|
||||
|
||||
def test():
    """Run the differential check on random string pairs.

    Up to 10000 pairs of random 30-character strings are passed to ``check``.
    On the first pair for which ``check`` reports a mismatch, the two strings
    are printed and the loop stops.
    """
    rounds = 10000
    size = 30
    for _ in range(rounds):
        a = random_string(size)
        b = random_string(size)
        if check(a, b):
            continue
        # First disagreement found: show the offending inputs and stop.
        print(a, b)
        break
|
||||
|
||||
|
||||
# Run the differential check only when executed as a script, so that merely
# importing this module does not start the 10000-iteration loop.
if __name__ == "__main__":
    test()
|
||||
Reference in New Issue
Block a user