"""Cross-check two ways of counting the edit operations that turn a into b.

f(a, b) computes the edit operations from a to b using the method of the
previous ASR leaderboard: align the two token sequences with ``None``
placeholders, then count operations position by position.
g(a, b) computes the edit operations from a to b the original way, by
counting the operations in the raw Levenshtein edit script.
test() is a randomized differential test that compares f against g.
"""

import random
import string
from collections import Counter
from copy import deepcopy
from typing import List, Optional, Tuple

import Levenshtein


def mapping(gt: str, dt: str) -> Tuple[List[str], List[str]]:
    """Split the reference (gt) and hypothesis (dt) strings into character tokens."""
    return list(gt), list(dt)


def token_mapping(
    tokens_gt: List[str], tokens_dt: List[str]
) -> Tuple[List[Optional[str]], List[Optional[str]]]:
    """Align two token sequences using the Levenshtein edit script.

    A ``None`` placeholder is inserted into the reference copy for every
    "insert" operation and into the hypothesis copy for every "delete"
    operation, so that corresponding tokens end up at the same index in the
    two returned lists. The caller's lists are copied, not modified.

    Args:
        tokens_gt: Reference tokens.
        tokens_dt: Hypothesis tokens.

    Returns:
        ``(aligned_gt, aligned_dt)`` with ``None`` marking gaps.
    """
    arr1 = deepcopy(tokens_gt)
    arr2 = deepcopy(tokens_dt)
    operations = Levenshtein.editops(arr1, arr2)
    # Apply the operations right-to-left: each insertion only shifts elements
    # at or after its own position, so the original positions of the
    # remaining (earlier) operations stay valid.
    for op in operations[::-1]:
        if op[0] == "insert":
            arr1.insert(op[1], None)
        elif op[0] == "delete":
            arr2.insert(op[2], None)
    return arr1, arr2


def cer(
    tokens_gt_mapping: List[Optional[str]], tokens_dt_mapping: List[Optional[str]]
) -> Tuple[int, int, int]:
    """Count edit operations on two aligned token sequences.

    The inputs are the output of ``token_mapping``: ``None`` in the reference
    sequence marks an insertion, ``None`` in the hypothesis sequence marks a
    deletion.

    Returns:
        ``(replace, delete, insert)`` counts.
    """
    insert = sum(1 for item in tokens_gt_mapping if item is None)
    delete = sum(1 for item in tokens_dt_mapping if item is None)
    equal = sum(
        1
        for token_gt, token_dt in zip(tokens_gt_mapping, tokens_dt_mapping)
        if token_gt == token_dt
    )
    # NOTE(review): ``delete`` is not subtracted here (the original had the
    # ``- delete`` term commented out), so every deleted reference token is
    # also counted as a replacement, and f() reports more replacements than
    # g() whenever deletions occur. Presumably this deliberately reproduces
    # the old leaderboard formula that test() is meant to expose. Confirm
    # before "fixing" it to ``len(...) - insert - delete - equal``.
    replace = len(tokens_gt_mapping) - insert - equal
    return replace, delete, insert


def f(a: str, b: str) -> Tuple[int, int, int]:
    """Return ``(replace, delete, insert)`` for a -> b via alignment (leaderboard method)."""
    return cer(*token_mapping(*mapping(a, b)))


def raw(tokens_gt, tokens_dt) -> Tuple[int, int, int]:
    """Return ``(replace, delete, insert)`` counted directly from the edit script."""
    # The inputs are only read here, so no defensive copy is needed.
    counts = Counter(op[0] for op in Levenshtein.editops(tokens_gt, tokens_dt))
    return counts["replace"], counts["delete"], counts["insert"]


def g(a: str, b: str) -> Tuple[int, int, int]:
    """Return ``(replace, delete, insert)`` for a -> b via the raw edit script."""
    return raw(*mapping(a, b))


def check(a: str, b: str) -> bool:
    """Return whether f and g agree on (a, b); print both results if not."""
    ff = f(a, b)
    gg = g(a, b)
    if ff != gg:
        print(ff, gg)
    return ff == gg


def random_string(length: int) -> str:
    """Return a random string of ``length`` lowercase ASCII letters."""
    letters = string.ascii_lowercase
    return "".join(random.choice(letters) for _ in range(length))


def test(rounds: int = 10000, length: int = 30) -> None:
    """Compare f and g on random string pairs; print the first mismatching pair.

    Args:
        rounds: Number of random pairs to try.
        length: Length of each random string.
    """
    for _ in range(rounds):
        a = random_string(length)
        b = random_string(length)
        if not check(a, b):
            print(a, b)
            break


if __name__ == "__main__":
    test()