58 lines
2.2 KiB
Python
58 lines
2.2 KiB
Python
# Copied from:
# https://gitlab.4pd.io/scene_lab/leaderboard/judge_flows/foundamental_capability/blob/master/utils/asr_ter.py
|
|
|
|
|
|
def calc_ter_speechio(pred, ref, language="zh"):
    """Compute the token error rate of ``pred`` against ``ref`` using the speechio helpers.

    Both texts are normalized with ``textnorm_zh.TextNorm``, tokenized with the
    ``"char"`` tokenizer and aligned with ``error_rate_zh.EditDistance``.

    Args:
        pred: Recognized text. ``None`` is treated as an empty string.
        ref: Reference transcript; must be neither ``None`` nor ``""``.
        language: Language code. Only ``"zh"`` is supported.

    Returns:
        A dict with three keys:
            ``"ter"``: result of ``ComputeTokenErrorRate`` divided by 100
                (presumably converting a percentage to a fraction -- confirm
                against the speechio module).
            ``"err_token_cnt"``: ``s + d + i`` from ``CountEdits``.
            ``"ref_all_token_cnt"``: ``s + d + c`` from ``CountEdits``.

    Raises:
        AssertionError: If ``language`` is not ``"zh"`` or ``ref`` is empty.
    """
    # Raise explicitly instead of using `assert` so the checks still run under
    # `python -O`. AssertionError is kept so existing callers see the same type.
    if language != "zh":
        raise AssertionError("Unsupported language %s" % language)
    if ref is None or ref == "":
        raise AssertionError("Reference script cannot be empty")

    # Function-scope imports: only executed once validation has passed.
    from .speechio import error_rate_zh as error_rate
    from .speechio import textnorm_zh as textnorm

    normalizer = textnorm.TextNorm(
        to_banjiao=True,
        to_upper=True,
        to_lower=False,
        remove_fillers=True,
        remove_erhua=True,
        check_chars=False,
        remove_space=False,
        cc_mode="",
    )
    norm_pred = normalizer(pred if pred is not None else "")
    norm_ref = normalizer(ref)

    tokenizer = "char"
    # The second element returned by EditDistance is not needed here.
    alignment, _score = error_rate.EditDistance(
        error_rate.tokenize_text(norm_ref, tokenizer),
        error_rate.tokenize_text(norm_pred, tokenizer),
    )
    # c, s, i, d: presumably correct / substitution / insertion / deletion counts.
    c, s, i, d = error_rate.CountEdits(alignment)
    ter = error_rate.ComputeTokenErrorRate(c, s, i, d) / 100.0
    return {"ter": ter, "err_token_cnt": s + d + i, "ref_all_token_cnt": s + d + c}
|
|
|
|
|
|
def calc_ter_wjs(pred, ref, language="zh"):
    """Compute the token error rate of ``pred`` against ``ref`` using ``wjs_asr_wer``.

    Both texts go through ``wjs_asr_wer.characterize`` and
    ``wjs_asr_wer.normalize`` (no ignored words, case-insensitive, no split
    table) before being scored by ``wjs_asr_wer.Calculator``.

    Args:
        pred: Recognized text. ``None`` is treated as an empty string.
        ref: Reference transcript; must be neither ``None`` nor ``""``.
        language: Language code. Only ``"zh"`` is supported.

    Returns:
        A dict with three keys:
            ``"ter"``: ``(ins + sub + del) / all`` from the calculator result,
                or ``1.0`` when ``all`` is 0.
            ``"err_token_cnt"``: ``ins + sub + del``.
            ``"ref_all_token_cnt"``: the calculator's ``all`` count.

    Raises:
        AssertionError: If ``language`` is not ``"zh"`` or ``ref`` is empty.
    """
    # Raise explicitly instead of using `assert` so the checks still run under
    # `python -O`. AssertionError is kept so existing callers see the same type.
    if language != "zh":
        raise AssertionError("Unsupported language %s" % language)
    if ref is None or ref == "":
        raise AssertionError("Reference script cannot be empty")

    # Function-scope import: only executed once validation has passed.
    from . import wjs_asr_wer

    ignore_words = set()
    case_sensitive = False
    split = None

    calculator = wjs_asr_wer.Calculator()
    norm_pred = wjs_asr_wer.normalize(
        wjs_asr_wer.characterize(pred if pred is not None else ""),
        ignore_words,
        case_sensitive,
        split,
    )
    norm_ref = wjs_asr_wer.normalize(
        wjs_asr_wer.characterize(ref),
        ignore_words,
        case_sensitive,
        split,
    )
    result = calculator.calculate(norm_pred, norm_ref)

    err_token_cnt = result["ins"] + result["sub"] + result["del"]
    ref_all_token_cnt = result["all"]
    # With no scored tokens the rate is undefined; report 1.0 in that case.
    if ref_all_token_cnt != 0:
        ter = err_token_cnt * 1.0 / ref_all_token_cnt
    else:
        ter = 1.0
    return {
        "ter": ter,
        "err_token_cnt": err_token_cnt,
        "ref_all_token_cnt": ref_all_token_cnt,
    }
|