Files
enginex-bi_series-vc-cnn/tests/test_cer.py
zhousha 55a67e817e update
2025-08-06 15:38:55 +08:00

94 lines
2.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

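"""
f(a, b) computes the edit distance a -> b using the method from the earlier ASR leaderboard.
g(a, b) computes the edit distance a -> b using the plain (original) edit-distance calculation.
test() is the differential test that compares the two on random inputs.
"""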
import random
import string
from copy import deepcopy
from typing import List, Tuple
import Levenshtein
def mapping(gt: str, dt: str):
    """Split both strings into per-character token lists.

    Returns:
        A ``(gt_tokens, dt_tokens)`` pair of freshly built lists.
    """
    gt_tokens = list(gt)
    dt_tokens = list(dt)
    return gt_tokens, dt_tokens
def token_mapping(
    tokens_gt: List[str], tokens_dt: List[str]
) -> Tuple[List[str], List[str]]:
    """Pad copies of both token lists with ``None`` according to the edit ops.

    The edit operations between the two sequences are obtained from
    ``Levenshtein.editops``. For every "insert" operation a ``None`` is
    inserted into the ground-truth copy at the operation's second field;
    for every "delete" operation a ``None`` is inserted into the detected
    copy at the operation's third field. The inputs are not modified.

    Returns:
        The padded ``(ground_truth, detected)`` lists.
    """
    aligned_gt = deepcopy(tokens_gt)
    aligned_dt = deepcopy(tokens_dt)
    edit_ops = Levenshtein.editops(aligned_gt, aligned_dt)
    # Walk the operations back to front so that inserting a placeholder
    # never shifts a position that a not-yet-processed operation refers to.
    for op in reversed(edit_ops):
        tag = op[0]
        if tag == "insert":
            aligned_gt.insert(op[1], None)
            continue
        if tag == "delete":
            aligned_dt.insert(op[2], None)
    return aligned_gt, aligned_dt
def cer(tokens_gt_mapping: List[str], tokens_dt_mapping: List[str]):
    """Count edit operations between two aligned token sequences.

    Both arguments are the position-aligned sequences in which ``None``
    marks a gap: a ``None`` in ``tokens_gt_mapping`` is counted as an
    insertion and a ``None`` in ``tokens_dt_mapping`` as a deletion.

    Args:
        tokens_gt_mapping: Aligned ground-truth tokens (``None`` = gap).
        tokens_dt_mapping: Aligned detected tokens (``None`` = gap).

    Returns:
        A ``(replace, delete, insert)`` tuple of operation counts.
    """
    insert = sum(1 for item in tokens_gt_mapping if item is None)
    delete = sum(1 for item in tokens_dt_mapping if item is None)
    equal = sum(
        1
        for token_gt, token_dt in zip(tokens_gt_mapping, tokens_dt_mapping)
        if token_gt == token_dt
    )
    # Each aligned position is exactly one of equal / replace / insert /
    # delete, so deletions must be subtracted as well; otherwise every
    # deletion would also be counted as a replacement.
    replace = len(tokens_gt_mapping) - insert - delete - equal
    return replace, delete, insert
def f(a, b):
    """Return the tuple reported by ``cer`` for strings ``a`` and ``b``.

    The strings are split into characters with ``mapping``, aligned with
    ``token_mapping``, and the aligned sequences are passed to ``cer``.
    """
    tokens_a, tokens_b = mapping(a, b)
    aligned_a, aligned_b = token_mapping(tokens_a, tokens_b)
    return cer(aligned_a, aligned_b)
def raw(tokens_gt, tokens_dt):
    """Count edit operations directly from ``Levenshtein.editops``.

    Returns:
        A ``(replace, delete, insert)`` tuple with the number of
        operations carrying each of those tags.
    """
    source = deepcopy(tokens_gt)
    target = deepcopy(tokens_dt)
    tally = {"replace": 0, "delete": 0, "insert": 0}
    for op in Levenshtein.editops(source, target):
        kind = op[0]
        # Operations with any other tag are ignored.
        if kind in tally:
            tally[kind] += 1
    return tally["replace"], tally["delete"], tally["insert"]
def g(a, b):
    """Return the tuple reported by ``raw`` for strings ``a`` and ``b``.

    The strings are split into characters with ``mapping`` first.
    """
    tokens_a, tokens_b = mapping(a, b)
    return raw(tokens_a, tokens_b)
def check(a, b):
    """Compare the results of ``f`` and ``g`` for one pair of strings.

    When the two results differ, both are printed.

    Returns:
        ``True`` if ``f(a, b)`` equals ``g(a, b)``, otherwise ``False``.
    """
    mapped_counts = f(a, b)
    raw_counts = g(a, b)
    agree = mapped_counts == raw_counts
    if not agree:
        print(mapped_counts, raw_counts)
    return agree
def random_string(length):
    """Build a string of ``length`` randomly chosen lowercase ASCII letters."""
    alphabet = string.ascii_lowercase
    picked = [random.choice(alphabet) for _ in range(length)]
    return "".join(picked)
def test():
    """Run ``check`` on up to 10000 random pairs of 30-letter strings.

    Stops at the first pair for which ``check`` returns a falsy value,
    after printing that pair.
    """
    remaining = 10000
    while remaining > 0:
        remaining -= 1
        a = random_string(30)
        b = random_string(30)
        if check(a, b):
            continue
        print(a, b)
        return
# Run the differential test whenever this module is executed or imported.
test()