Files
enginex-bi_series-vc-cnn/utils/evaluator.py
zhousha 55a67e817e update
2025-08-06 15:38:55 +08:00

196 lines
7.7 KiB
Python

# coding: utf-8
import os
from collections import Counter, defaultdict
from itertools import chain
from typing import List
from schemas.context import ASRContext
from utils.logger import logger
from utils.metrics import cer, cut_rate, cut_sentence, first_delay
from utils.metrics import mean_on_counter, patch_unique_token_count
from utils.metrics import revision_delay, text_align, token_mapping
from utils.metrics import var_on_counter
from utils.tokenizer import TOKENIZER_MAPPING, Tokenizer
from utils.update_submit import change_product_available
# IN_TEST gates the debug prints inside BaseEvaluator.evaluate.
# NOTE(review): os.getenv(key, 1) returns the env value or the default 1,
# never None — so IN_TEST is always False and the prints are dead code.
# This may be a deliberate kill-switch (drop the `, 1` to re-enable when
# SUBMIT_CONFIG_FILEPATH is unset) — confirm intent with the author.
IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", 1) is None
class BaseEvaluator:
    """Accumulates per-query ASR evaluation metrics and builds a summary report.

    ``evaluate`` is called once per :class:`ASRContext` (one query / audio
    clip) and folds that query's metrics into the running totals.
    ``gen_result`` aggregates everything seen so far into a plain dict and,
    as a side effect, may flag the product unavailable when quality
    thresholds are violated.  ``gen_detail_case`` returns the most recently
    evaluated context for detailed inspection.
    """

    def __init__(self) -> None:
        self.query_count = 0  # number of queries (audio clips) evaluated
        self.voice_count = 0  # total number of labelled segments across queries
        self.fail_count = 0  # number of failed queries
        # first-token delay accumulators
        self.first_delay_sum = 0
        self.first_delay_cnt = 0
        # revision delay accumulators
        self.revision_delay_sum = 0
        self.revision_delay_cnt = 0
        # patch-token statistics (distribution of unique patched tokens)
        self.patch_unique_cnt_counter = Counter()
        # text alignment counts
        self.start_time_align_count = 0
        self.end_time_align_count = 0
        self.start_end_count = 0
        # accumulated (1 - CER) over all queries
        self.one_minus_cer = 0
        self.token_count = 0
        # per-language (1 - CER) and query counts
        self.one_minus_cer_lang = defaultdict(int)
        self.query_count_lang = defaultdict(int)
        # sentence-cut statistics
        self.miss_count = 0
        self.more_count = 0
        self.sentence_count = 0
        self.cut_rate = 0
        # last evaluated context, kept as the detail case
        self.context = ASRContext()
        # latency samples (seconds/ms depending on upstream timestamps)
        self.send_interval = []
        self.last_recv_interval = []
        # number of clips failing the char-contains-rate threshold
        self.fail_char_contains_rate_num = 0
        # punctuation counters (ground truth vs. predicted)
        self.punctuation_num = 0
        self.pred_punctuation_num = 0

    def evaluate(self, context: ASRContext):
        """Fold one query's ASRContext into the running metric totals."""
        self.query_count += 1
        self.query_count_lang[context.lang] += 1
        self.voice_count += len(context.labels)
        self.punctuation_num += context.punctuation_num
        self.pred_punctuation_num += context.pred_punctuation_num
        if not context.fail:
            # first-token delay
            first_delay_sum, first_delay_cnt = first_delay(context)
            self.first_delay_sum += first_delay_sum
            self.first_delay_cnt += first_delay_cnt
            # revision delay
            revision_delay_sum, revision_delay_cnt = revision_delay(context)
            self.revision_delay_sum += revision_delay_sum
            self.revision_delay_cnt += revision_delay_cnt
            # patch-token distribution
            self.patch_unique_cnt_counter += patch_unique_token_count(context)
        else:
            self.fail_count += 1
        self.fail_char_contains_rate_num += context.fail_char_contains_rate_num
        # text alignment counts
        start_align_cnt, end_align_cnt, start_end_cnt = text_align(context)
        self.start_time_align_count += start_align_cnt
        self.end_time_align_count += end_align_cnt
        self.start_end_count += start_end_cnt
        # ground-truth answers and final recognized texts
        sentences_gt: List[str] = [item.answer for item in context.labels]
        sentences_dt: List[str] = [
            item.recognition_results.text for item in context.preds if item.recognition_results.final_result
        ]
        if IN_TEST:
            print(sentences_gt)
            print(sentences_dt)
        sentences_gt = cut_sentence(sentences_gt, TOKENIZER_MAPPING.get(context.lang))
        sentences_dt = cut_sentence(sentences_dt, TOKENIZER_MAPPING.get(context.lang))
        if IN_TEST:
            print(sentences_gt)
            print(sentences_dt)
        # normalize & tokenize per language
        tokens_gt: List[List[str]] = Tokenizer.norm_and_tokenize(sentences_gt, context.lang)
        tokens_dt: List[List[str]] = Tokenizer.norm_and_tokenize(sentences_dt, context.lang)
        if IN_TEST:
            print(tokens_gt)
            print(tokens_dt)
        # CER over the flattened (sentence-concatenated) token streams
        tokens_gt_mapping, tokens_dt_mapping = token_mapping(list(chain(*tokens_gt)), list(chain(*tokens_dt)))
        one_minus_cer, token_count = cer(tokens_gt_mapping, tokens_dt_mapping)
        self.one_minus_cer += one_minus_cer
        self.token_count += token_count
        self.one_minus_cer_lang[context.lang] += one_minus_cer
        # sentence-cut rate
        rate, sentence_cnt, miss_cnt, more_cnt = cut_rate(tokens_gt, tokens_dt, tokens_gt_mapping, tokens_dt_mapping)
        self.cut_rate += rate
        self.sentence_count += sentence_cnt
        self.miss_count += miss_cnt
        self.more_count += more_cnt
        # keep the latest context as the detail case
        self.context = context
        # latency: first-send -> last-send, and first-send -> last-receive
        if self.context.send_time_start_end and self.context.recv_time_start_end:
            send_interval = self.context.send_time_start_end[1] - self.context.send_time_start_end[0]
            # NOTE(review): measured from the first SEND time (not first
            # receive) — presumed intentional end-to-end latency; confirm.
            recv_interval = self.context.recv_time_start_end[1] - self.context.send_time_start_end[0]
            self.send_interval.append(send_interval)
            self.last_recv_interval.append(recv_interval)
            logger.info(
                f"""第一次发送时间{self.context.send_time_start_end[0]}, \
最后一次发送时间{self.context.send_time_start_end[-1]}, \
发送间隔 {send_interval},
最后一次接收时间{self.context.recv_time_start_end[-1]}, \
接收间隔 {recv_interval}
"""
            )

    def post_evaluate(self):
        """Hook for subclasses to run work after all queries are evaluated."""
        pass

    def gen_result(self):
        """Return the aggregated metrics as a plain dict.

        Side effect: calls ``change_product_available()`` when the mean
        first delay exceeds the FIRST_DELAY_THRESHOLD env var (default 5)
        or when more than 10% of segments fail the char-contains-rate check.

        All rate denominators are guarded so an empty run yields 0 (or the
        10-unit penalty default for delays) instead of ZeroDivisionError.
        """
        result = {
            "query_count": self.query_count,
            "voice_count": self.voice_count,
            "pred_voice_count": self.first_delay_cnt,
            # 10 is a penalty default used when nothing was predicted
            "first_delay_mean": self.first_delay_sum / self.first_delay_cnt if self.first_delay_cnt > 0 else 10,
            "revision_delay_mean": (
                self.revision_delay_sum / self.revision_delay_cnt if self.revision_delay_cnt > 0 else 10
            ),
            "patch_token_mean": mean_on_counter(self.patch_unique_cnt_counter),
            "patch_token_var": var_on_counter(self.patch_unique_cnt_counter),
            "start_time_align_count": self.start_time_align_count,
            "end_time_align_count": self.end_time_align_count,
            "start_time_align_rate": (
                self.start_time_align_count / self.sentence_count if self.sentence_count > 0 else 0
            ),
            "end_time_align_rate": (
                self.end_time_align_count / self.sentence_count if self.sentence_count > 0 else 0
            ),
            "start_end_count": self.start_end_count,
            "one_minus_cer": self.one_minus_cer / self.query_count if self.query_count > 0 else 0,
            "token_count": self.token_count,
            "miss_count": self.miss_count,
            "more_count": self.more_count,
            "sentence_count": self.sentence_count,
            "cut_rate": self.cut_rate / self.query_count if self.query_count > 0 else 0,
            "fail_count": self.fail_count,
            "send_interval": self.send_interval,
            "last_recv_interval": self.last_recv_interval,
            "fail_char_contains_rate_num": self.fail_char_contains_rate_num,
            "punctuation_rate": (
                self.pred_punctuation_num / self.punctuation_num if self.punctuation_num > 0 else 0
            ),
        }
        for lang, one_minus_cer_sum in self.one_minus_cer_lang.items():
            # query_count_lang[lang] > 0 whenever lang has a CER entry,
            # because both are updated together in evaluate().
            result["one_minus_cer_" + lang] = one_minus_cer_sum / self.query_count_lang[lang]
        fail_char_rate = (
            self.fail_char_contains_rate_num / self.voice_count if self.voice_count > 0 else 0
        )
        if (
            result["first_delay_mean"] > float(os.getenv("FIRST_DELAY_THRESHOLD", "5"))
            or fail_char_rate > 0.1
            # or result["punctuation_rate"] < 0.8  # punctuation threshold disabled
        ):
            change_product_available()
        return result

    def gen_detail_case(self):
        """Return the most recently evaluated context (the detail case)."""
        return self.context