94 lines
2.2 KiB
Python
94 lines
2.2 KiB
Python
|
|
"""
|
|||
|
|
f(a, b) 计算 a -> b 的编辑距离,使用的方法是之前asr榜单的方法
|
|||
|
|
g(a, b) 计算 a -> b 的编辑距离,使用的是原始的编辑距离计算方法
|
|||
|
|
test() 是对拍程序
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import random
|
|||
|
|
import string
|
|||
|
|
from copy import deepcopy
|
|||
|
|
from typing import List, Tuple
|
|||
|
|
|
|||
|
|
import Levenshtein
|
|||
|
|
|
|||
|
|
|
|||
|
|
def mapping(gt: str, dt: str):
    """Split both strings into lists of single characters.

    Returns:
        A pair ``(gt_chars, dt_chars)`` holding the characters of ``gt`` and
        ``dt`` in their original order.
    """
    gt_chars = list(gt)
    dt_chars = list(dt)
    return gt_chars, dt_chars
|
|||
|
|
|
|||
|
|
|
|||
|
|
def token_mapping(
    tokens_gt: List[str], tokens_dt: List[str]
) -> Tuple[List[str], List[str]]:
    """Pad copies of both token lists with ``None`` according to the edit operations.

    For every "insert" operation a ``None`` is placed into the copy of
    ``tokens_gt`` at the operation's source position; for every "delete"
    operation a ``None`` is placed into the copy of ``tokens_dt`` at the
    operation's destination position. The inputs themselves are not modified.

    Returns:
        The padded copies ``(gt_copy, dt_copy)``.
    """
    aligned_gt = deepcopy(tokens_gt)
    aligned_dt = deepcopy(tokens_dt)
    edit_ops = Levenshtein.editops(aligned_gt, aligned_dt)
    # Apply the operations back to front so that inserting a placeholder does
    # not shift the positions of the operations that are still pending.
    for tag, src_pos, dst_pos in reversed(edit_ops):
        if tag == "insert":
            aligned_gt.insert(src_pos, None)
        elif tag == "delete":
            aligned_dt.insert(dst_pos, None)
    return aligned_gt, aligned_dt
|
|||
|
|
|
|||
|
|
|
|||
|
|
def cer(
    tokens_gt_mapping: List[str], tokens_dt_mapping: List[str]
) -> Tuple[int, int, int]:
    """Count replacements, deletions and insertions between two aligned sequences.

    Both arguments are token sequences that have already been padded with
    ``None`` placeholders so that they line up position by position:

    * a ``None`` in ``tokens_gt_mapping`` is counted as an insertion,
    * a ``None`` in ``tokens_dt_mapping`` is counted as a deletion,
    * positions where both tokens compare equal are matches,
    * every remaining position is a replacement.

    Returns:
        The tuple ``(replace, delete, insert)``.
    """
    insert = sum(1 for item in tokens_gt_mapping if item is None)
    delete = sum(1 for item in tokens_dt_mapping if item is None)
    equal = sum(
        1
        for token_gt, token_dt in zip(tokens_gt_mapping, tokens_dt_mapping)
        if token_gt == token_dt
    )
    # Every aligned position is exactly one of insert / delete / equal /
    # replace, so deletions must be subtracted as well; otherwise each
    # deletion would be reported a second time as a replacement.
    replace = len(tokens_gt_mapping) - insert - delete - equal
    return replace, delete, insert
|
|||
|
|
|
|||
|
|
|
|||
|
|
def f(a, b):
    """Return the counts ``cer`` reports for ``a`` and ``b`` after alignment.

    The strings are split into characters with ``mapping``, padded with
    ``token_mapping``, and the padded sequences are handed to ``cer``.
    """
    chars_gt, chars_dt = mapping(a, b)
    aligned_gt, aligned_dt = token_mapping(chars_gt, chars_dt)
    return cer(aligned_gt, aligned_dt)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def raw(tokens_gt, tokens_dt):
    """Count the edit operations reported by ``Levenshtein.editops``.

    Args:
        tokens_gt: Source token sequence.
        tokens_dt: Destination token sequence.

    Returns:
        The tuple ``(replace, delete, insert)`` with the number of operations
        of each kind.
    """
    counts = {"replace": 0, "delete": 0, "insert": 0}
    # The inputs are only read here, never modified, so they are passed to
    # editops directly instead of being deep-copied first.
    for op in Levenshtein.editops(tokens_gt, tokens_dt):
        tag = op[0]
        if tag in counts:
            counts[tag] += 1
    return counts["replace"], counts["delete"], counts["insert"]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def g(a, b):
    """Return the operation counts ``raw`` reports for the characters of ``a`` and ``b``."""
    chars_gt, chars_dt = mapping(a, b)
    return raw(chars_gt, chars_dt)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def check(a, b):
    """Compare ``f(a, b)`` with ``g(a, b)``.

    Prints both results when they differ.

    Returns:
        ``True`` when the two results are equal, ``False`` otherwise.
    """
    mapped_counts = f(a, b)
    raw_counts = g(a, b)
    matches = mapped_counts == raw_counts
    if not matches:
        print(mapped_counts, raw_counts)
    return matches
|
|||
|
|
|
|||
|
|
|
|||
|
|
def random_string(length):
    """Build a string of ``length`` randomly chosen lowercase ASCII letters."""
    alphabet = string.ascii_lowercase
    # One random.choice call per character, in order.
    picked = [random.choice(alphabet) for _ in range(length)]
    return "".join(picked)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test():
    """Cross-check ``f`` and ``g`` on up to 10000 random pairs of 30-letter strings.

    Stops at the first pair for which ``check`` fails and prints that pair.
    """
    for _ in range(10000):
        reference = random_string(30)
        hypothesis = random_string(30)
        if check(reference, hypothesis):
            continue
        # First disagreement: show the offending inputs and stop.
        print(reference, hypothesis)
        break
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Run the differential test at module level, i.e. whenever this file is
# executed or imported.
test()
|