This commit is contained in:
zhousha
2025-08-06 15:38:55 +08:00
parent 4916ad0fe0
commit 55a67e817e
193 changed files with 51647 additions and 1 deletions

View File

@@ -0,0 +1,16 @@
import json
from schemas.dataset import QueryData
from schemas.stream import StreamDataModel
from utils.evaluator_plus import evaluate_editops
# Replay the first recorded case from the dumped evaluation details and print
# the result of evaluate_editops for it.
# JSON is read as UTF-8 explicitly so decoding does not depend on the
# platform's locale default encoding.
with open("out/detail_cases.json", encoding="utf-8") as f:
    detail_cases = json.load(f)

# Only the first case in the dump is evaluated.
detail_case = detail_cases[0]

# Rebuild the model objects from their JSON representation.
preds = [StreamDataModel.model_validate(pred) for pred in detail_case["preds"]]
label = QueryData.model_validate(detail_case["label"])

print(evaluate_editops(label, preds))

93
tests/test_cer.py Normal file
View File

@@ -0,0 +1,93 @@
"""
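f(a, b): edit-operation counts for a -> b, computed with the method used by the earlier ASR leaderboard.
g(a, b): edit-operation counts for a -> b, computed directly from the raw edit operations.
test(): differential test that checks f against g on random strings.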
"""
import random
import string
from copy import deepcopy
from typing import List, Tuple
import Levenshtein
def mapping(gt: str, dt: str) -> Tuple[List[str], List[str]]:
    """Split two strings into lists of their individual characters.

    Args:
        gt: First string (presumably the ground-truth text).
        dt: Second string (presumably the text being compared against it).

    Returns:
        A pair ``(chars_of_gt, chars_of_dt)``; each element is a new list
        holding the characters of the corresponding string in order.
    """
    return list(gt), list(dt)
def token_mapping(
    tokens_gt: List[str], tokens_dt: List[str]
) -> Tuple[List[str], List[str]]:
    """Align two token sequences by padding them with ``None`` gaps.

    The edit operations between the two sequences come from
    ``Levenshtein.editops``. For every "insert" operation a ``None`` is put
    into the copy of ``tokens_gt`` at the operation's source position; for
    every "delete" operation a ``None`` is put into the copy of ``tokens_dt``
    at the operation's destination position. The inputs themselves are not
    modified, because the work is done on deep copies.

    Returns:
        The padded copy of ``tokens_gt`` followed by the padded copy of
        ``tokens_dt``.
    """
    padded_gt = deepcopy(tokens_gt)
    padded_dt = deepcopy(tokens_dt)
    edit_ops = Levenshtein.editops(padded_gt, padded_dt)
    # Apply the operations back to front so that a gap inserted now does not
    # shift the positions used by operations handled later (assumes editops
    # are reported in ascending position order).
    for op in reversed(edit_ops):
        kind, src_pos, dst_pos = op[0], op[1], op[2]
        if kind == "delete":
            padded_dt.insert(dst_pos, None)
        elif kind == "insert":
            padded_gt.insert(src_pos, None)
    return padded_gt, padded_dt
def cer(tokens_gt_mapping: List[str], tokens_dt_mapping: List[str]):
    """Count replace/delete/insert operations between two aligned sequences.

    Both arguments are expected to be aligned token sequences of equal
    length in which ``None`` marks a gap: a ``None`` in
    ``tokens_gt_mapping`` is counted as an insertion and a ``None`` in
    ``tokens_dt_mapping`` is counted as a deletion.

    Args:
        tokens_gt_mapping: Aligned first sequence, with ``None`` gaps.
        tokens_dt_mapping: Aligned second sequence, with ``None`` gaps.

    Returns:
        A tuple ``(replace, delete, insert)`` of operation counts.
    """
    insert = sum(1 for token in tokens_gt_mapping if token is None)
    delete = sum(1 for token in tokens_dt_mapping if token is None)
    equal = sum(
        1
        for token_gt, token_dt in zip(tokens_gt_mapping, tokens_dt_mapping)
        if token_gt == token_dt
    )
    # Every aligned position is exactly one of equal / replace / insert /
    # delete, so the deletions must be subtracted as well; otherwise each
    # deletion would also be counted as a replacement.
    replace = len(tokens_gt_mapping) - insert - equal - delete
    return replace, delete, insert
def f(a, b):
    """Return the ``cer`` counts for ``a`` -> ``b`` via the aligned sequences.

    The strings are split with ``mapping``, aligned with ``token_mapping``,
    and the aligned sequences are then handed to ``cer``.
    """
    tokens_a, tokens_b = mapping(a, b)
    aligned_a, aligned_b = token_mapping(tokens_a, tokens_b)
    return cer(aligned_a, aligned_b)
def raw(tokens_gt, tokens_dt):
    """Count the edit operations reported by ``Levenshtein.editops``.

    The operations are computed on deep copies of the two inputs and tallied
    by their tag.

    Returns:
        A tuple ``(replace, delete, insert)`` of operation counts.
    """
    source = deepcopy(tokens_gt)
    target = deepcopy(tokens_dt)
    tally = {"replace": 0, "delete": 0, "insert": 0}
    for op in Levenshtein.editops(source, target):
        kind = op[0]
        # Tags other than the three tracked ones are ignored.
        if kind in tally:
            tally[kind] += 1
    return tally["replace"], tally["delete"], tally["insert"]
def g(a, b):
    """Return the ``raw`` edit-operation counts for ``a`` -> ``b``.

    The strings are split with ``mapping`` and passed straight to ``raw``.
    """
    tokens_a, tokens_b = mapping(a, b)
    return raw(tokens_a, tokens_b)
def check(a, b):
    """Compare ``f`` and ``g`` on one pair of strings.

    When the two results differ, both are printed.

    Returns:
        ``True`` if ``f(a, b)`` equals ``g(a, b)``, otherwise ``False``.
    """
    via_alignment = f(a, b)
    via_editops = g(a, b)
    agree = via_alignment == via_editops
    if not agree:
        print(via_alignment, via_editops)
    return agree
def random_string(length):
    """Return a random string of ``length`` lowercase ASCII letters.

    Each character is drawn independently with ``random.choice``.
    """
    alphabet = string.ascii_lowercase
    drawn = (random.choice(alphabet) for _ in range(length))
    return "".join(drawn)
def test():
    """Differential test of ``f`` against ``g`` on random string pairs.

    Runs up to 10000 trials on pairs of random 30-character strings. On the
    first pair for which ``check`` fails, the pair is printed and the loop
    stops.
    """
    for _ in range(10000):
        a, b = random_string(30), random_string(30)
        if check(a, b):
            continue
        # Show the offending inputs, then stop at the first mismatch.
        print(a, b)
        break
# Run the differential test as soon as this module is executed or imported.
test()