update
This commit is contained in:
16
tests/test_callback_editops.py
Normal file
16
tests/test_callback_editops.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import json
|
||||
|
||||
from schemas.dataset import QueryData
|
||||
from schemas.stream import StreamDataModel
|
||||
from utils.evaluator_plus import evaluate_editops
|
||||
|
||||
# Load detail cases recorded by a previous evaluation run.
# NOTE(review): path is relative to the CWD — presumably run from the repo root; confirm.
with open("out/detail_cases.json") as f:
    detail_cases = json.load(f)

# Only the first recorded case is exercised by this smoke test.
detail_case = detail_cases[0]

# Validate each raw prediction dict into a StreamDataModel
# (comprehension instead of the manual append loop).
preds = [StreamDataModel.model_validate(pred) for pred in detail_case["preds"]]
label = QueryData.model_validate(detail_case["label"])

# Print the edit-op evaluation result for manual inspection.
print(evaluate_editops(label, preds))
|
||||
93
tests/test_cer.py
Normal file
93
tests/test_cer.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
f(a, b) 计算 a -> b 的编辑距离,使用的方法是之前asr榜单的方法
|
||||
g(a, b) 计算 a -> b 的编辑距离,使用的是原始的编辑距离计算方法
|
||||
test() 是对拍程序
|
||||
"""
|
||||
|
||||
import random
|
||||
import string
|
||||
from copy import deepcopy
|
||||
from typing import List, Tuple
|
||||
|
||||
import Levenshtein
|
||||
|
||||
|
||||
def mapping(gt: str, dt: str) -> Tuple[List[str], List[str]]:
    """Tokenize ground-truth and detected strings into per-character lists.

    `list(s)` replaces the verbose `[i for i in s]` copies.
    """
    return list(gt), list(dt)
|
||||
|
||||
|
||||
def token_mapping(
    tokens_gt: List[str], tokens_dt: List[str]
) -> Tuple[List[str], List[str]]:
    """Align two token sequences position-by-position using None padding.

    Computes the edit script gt -> dt with Levenshtein.editops, then inserts
    a None placeholder on whichever side lacks a token at each insert/delete
    position.  The returned lists have equal length and line up index by index.
    """
    # Shallow copies suffice: elements are immutable strings and only the
    # list structure is mutated (the original code used an unnecessary deepcopy).
    aligned_gt = list(tokens_gt)
    aligned_dt = list(tokens_dt)
    operations = Levenshtein.editops(aligned_gt, aligned_dt)
    # Apply the edit script back-to-front so that insertions do not shift the
    # indices of operations that have not been applied yet.
    for op_name, src_pos, dst_pos in reversed(operations):
        if op_name == "insert":
            aligned_gt.insert(src_pos, None)
        elif op_name == "delete":
            aligned_dt.insert(dst_pos, None)
    return aligned_gt, aligned_dt
|
||||
|
||||
|
||||
def cer(tokens_gt_mapping: List[str], tokens_dt_mapping: List[str]):
    """Count edit operations between two aligned, None-padded token sequences.

    Inputs are the equal-length sequences produced by token_mapping().
    Returns a (replace, delete, insert) tuple.

    NOTE(review): the original docstring claimed this returns
    "1-cer, token-cnt", which does not match the code — fixed here.
    """
    # A None slot on the gt side marks a token present only in dt: an insert.
    insert = sum(1 for item in tokens_gt_mapping if item is None)
    # A None slot on the dt side marks a token present only in gt: a delete.
    delete = sum(1 for item in tokens_dt_mapping if item is None)
    equal = sum(
        1
        for token_gt, token_dt in zip(tokens_gt_mapping, tokens_dt_mapping)
        if token_gt == token_dt
    )
    # `delete` is NOT subtracted — the author left it commented out, which
    # appears deliberate (reproducing the previous ASR-leaderboard counting
    # method per the module docstring), so deleted positions are folded into
    # the replace count here.  Confirm before "fixing".
    replace = len(tokens_gt_mapping) - insert - equal  # - delete
    return replace, delete, insert
|
||||
|
||||
|
||||
def f(a, b):
    """Edit-operation counts for a -> b via the alignment-based method."""
    tokens_gt, tokens_dt = mapping(a, b)
    aligned_gt, aligned_dt = token_mapping(tokens_gt, tokens_dt)
    return cer(aligned_gt, aligned_dt)
|
||||
|
||||
|
||||
def raw(tokens_gt, tokens_dt):
    """Count edit operations a -> b straight from Levenshtein.editops.

    Returns a (replace, delete, insert) tuple, mirroring cer()'s shape.
    """
    # Levenshtein.editops does not mutate its arguments, so the original
    # deepcopy of both lists was unnecessary.
    insert = 0
    delete = 0
    replace = 0
    for op_name, _, _ in Levenshtein.editops(tokens_gt, tokens_dt):
        # The op kinds are mutually exclusive, so use an elif chain instead
        # of three independent ifs.
        if op_name == "insert":
            insert += 1
        elif op_name == "delete":
            delete += 1
        elif op_name == "replace":
            replace += 1
    return replace, delete, insert
|
||||
|
||||
|
||||
def g(a, b):
    """Edit-operation counts for a -> b via plain Levenshtein editops."""
    tokens_gt, tokens_dt = mapping(a, b)
    return raw(tokens_gt, tokens_dt)
|
||||
|
||||
|
||||
def check(a, b):
    """Return True when f and g agree on (a, b); print both results on mismatch."""
    result_fast = f(a, b)
    result_raw = g(a, b)
    agree = result_fast == result_raw
    if not agree:
        print(result_fast, result_raw)
    return agree
|
||||
|
||||
|
||||
def random_string(length):
    """Return a random string of `length` ASCII lowercase letters.

    random.choices with k= replaces the manual generator loop (which also
    left its loop variable `i` unused).
    """
    return "".join(random.choices(string.ascii_lowercase, k=length))
|
||||
|
||||
|
||||
def test():
    """Differential test: pit f against g on random 30-character strings."""
    for _ in range(10000):
        sample_a = random_string(30)
        sample_b = random_string(30)
        if check(sample_a, sample_b):
            continue
        print(sample_a, sample_b)
        break
|
||||
|
||||
|
||||
# Kick off the differential test whenever this module is run or imported.
test()
|
||||
Reference in New Issue
Block a user