94 lines
2.2 KiB
Python
94 lines
2.2 KiB
Python
"""
|
||
f(a, b) 计算 a -> b 的编辑距离,使用的方法是之前asr榜单的方法
|
||
g(a, b) 计算 a -> b 的编辑距离,使用的是原始的编辑距离计算方法
|
||
test() 是对拍程序
|
||
"""
|
||
|
||
import random
|
||
import string
|
||
from copy import deepcopy
|
||
from typing import List, Tuple
|
||
|
||
import Levenshtein
|
||
|
||
|
||
def mapping(gt: str, dt: str):
    """Split both input strings into lists of their individual characters.

    Returns a 2-tuple ``(chars_of_gt, chars_of_dt)``.
    """
    gt_chars = list(gt)
    dt_chars = list(dt)
    return gt_chars, dt_chars
|
||
|
||
|
||
def token_mapping(
    tokens_gt: List[str], tokens_dt: List[str]
) -> Tuple[List[str], List[str]]:
    """Pad copies of two token sequences with ``None`` according to their edit ops.

    The edit operations between the two sequences are obtained from
    ``Levenshtein.editops`` and walked from last to first, so that a
    placeholder inserted at a later index does not shift the indices used by
    the operations still to be processed:

    * "insert": a ``None`` is inserted into the ``tokens_gt`` copy at the
      operation's source position.
    * "delete": a ``None`` is inserted into the ``tokens_dt`` copy at the
      operation's destination position.
    * any other operation leaves both copies untouched.

    Returns the two padded copies; the arguments themselves are not modified.
    """
    padded_gt = deepcopy(tokens_gt)
    padded_dt = deepcopy(tokens_dt)
    for tag, src_pos, dst_pos in Levenshtein.editops(padded_gt, padded_dt)[::-1]:
        if tag == "delete":
            padded_dt.insert(dst_pos, None)
        elif tag == "insert":
            padded_gt.insert(src_pos, None)
    return padded_gt, padded_dt
|
||
|
||
|
||
def cer(
    tokens_gt_mapping: List[str], tokens_dt_mapping: List[str]
) -> Tuple[int, int, int]:
    """Count the edit operations encoded in two aligned token sequences.

    Both arguments are the ``None``-padded sequences: a ``None`` in
    ``tokens_gt_mapping`` is counted as an insertion, and a ``None`` in
    ``tokens_dt_mapping`` is counted as a deletion. Positions whose two tokens
    compare equal are matches; whatever remains is a replacement.

    Args:
        tokens_gt_mapping: first aligned sequence, ``None`` marks a gap.
        tokens_dt_mapping: second aligned sequence, ``None`` marks a gap.

    Returns:
        A ``(replace, delete, insert)`` tuple of counts.
    """
    insert = sum(1 for item in tokens_gt_mapping if item is None)
    delete = sum(1 for item in tokens_dt_mapping if item is None)
    equal = sum(
        1
        for token_gt, token_dt in zip(tokens_gt_mapping, tokens_dt_mapping)
        if token_gt == token_dt
    )
    # Every aligned position is exactly one of: match, replacement, insertion
    # or deletion. Deletions must therefore be subtracted as well; otherwise
    # each deletion would also be counted as a replacement.
    replace = len(tokens_gt_mapping) - insert - delete - equal
    return replace, delete, insert
|
||
|
||
|
||
def f(a, b):
    """Return the counts tuple from ``cer`` for the aligned characters of a and b.

    The strings are split into characters with ``mapping``, aligned with
    ``token_mapping``, and the aligned sequences are handed to ``cer``.
    """
    chars_a, chars_b = mapping(a, b)
    aligned_a, aligned_b = token_mapping(chars_a, chars_b)
    return cer(aligned_a, aligned_b)
|
||
|
||
|
||
def raw(tokens_gt, tokens_dt):
    """Count edit operations directly from ``Levenshtein.editops``.

    Args:
        tokens_gt: source token sequence.
        tokens_dt: destination token sequence.

    Returns:
        A ``(replace, delete, insert)`` tuple with the number of operations of
        each kind reported by ``Levenshtein.editops(tokens_gt, tokens_dt)``.
    """
    # Nothing in this function mutates the inputs, so they are passed straight
    # through instead of being deep-copied first.
    counts = {"replace": 0, "delete": 0, "insert": 0}
    for op in Levenshtein.editops(tokens_gt, tokens_dt):
        tag = op[0]
        # Operation kinds other than the three tracked ones are ignored.
        if tag in counts:
            counts[tag] += 1
    return counts["replace"], counts["delete"], counts["insert"]
|
||
|
||
|
||
def g(a, b):
    """Return the counts tuple from ``raw`` for the characters of a and b.

    The strings are split into characters with ``mapping`` and the two
    character lists are handed to ``raw``.
    """
    chars_a, chars_b = mapping(a, b)
    return raw(chars_a, chars_b)
|
||
|
||
|
||
def check(a, b):
    """Compare the results of ``f`` and ``g`` on the same pair of inputs.

    When the two results differ, both are printed. Returns ``True`` if the
    results are equal and ``False`` otherwise.
    """
    result_f = f(a, b)
    result_g = g(a, b)
    agree = result_f == result_g
    if not agree:
        print(result_f, result_g)
    return agree
|
||
|
||
|
||
def random_string(length):
    """Return a string of ``length`` randomly chosen lowercase ASCII letters."""
    alphabet = string.ascii_lowercase
    picked = [random.choice(alphabet) for _ in range(length)]
    return "".join(picked)
|
||
|
||
|
||
def test():
    """Cross-check ``f`` against ``g`` on random string pairs.

    Runs up to 10000 rounds, each with two fresh random strings of length 30.
    On the first pair for which ``check`` reports a mismatch, the two strings
    are printed and the loop stops.
    """
    rounds = 10000
    size = 30
    for _ in range(rounds):
        a, b = random_string(size), random_string(size)
        if check(a, b):
            continue
        print(a, b)
        break
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the cross-check only when this file is executed as a script, so that
    # importing the module does not trigger the 10000-round loop.
    test()
|