update
This commit is contained in:
195
utils/evaluator.py
Normal file
195
utils/evaluator.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# coding: utf-8
|
||||
|
||||
import os
|
||||
from collections import Counter, defaultdict
|
||||
from itertools import chain
|
||||
from typing import List
|
||||
|
||||
from schemas.context import ASRContext
|
||||
from utils.logger import logger
|
||||
from utils.metrics import cer, cut_rate, cut_sentence, first_delay
|
||||
from utils.metrics import mean_on_counter, patch_unique_token_count
|
||||
from utils.metrics import revision_delay, text_align, token_mapping
|
||||
from utils.metrics import var_on_counter
|
||||
from utils.tokenizer import TOKENIZER_MAPPING, Tokenizer
|
||||
from utils.update_submit import change_product_available
|
||||
|
||||
# Debug-print switch for the evaluate() pipeline below.
# NOTE(review): with a default of 1, os.getenv never returns None here, so
# IN_TEST is always False and the debug prints are permanently disabled.
# Confirm whether `os.getenv("SUBMIT_CONFIG_FILEPATH") is None` was intended
# (True only when no submit config is present, i.e. local testing).
IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", 1) is None
|
||||
|
||||
|
||||
class BaseEvaluator:
    """Accumulates per-query ASR evaluation metrics and builds a summary.

    Usage: call :meth:`evaluate` once per query with its ``ASRContext``,
    optionally :meth:`post_evaluate`, then :meth:`gen_result` for the
    aggregated metrics dict. :meth:`gen_detail_case` exposes the most
    recently evaluated context for detailed reporting.
    """

    def __init__(self) -> None:
        self.query_count = 0  # number of queries (voice requests)
        self.voice_count = 0  # number of labelled voice segments seen
        self.fail_count = 0  # number of failed queries
        # first-token delay
        self.first_delay_sum = 0
        self.first_delay_cnt = 0
        # revision delay
        self.revision_delay_sum = 0
        self.revision_delay_cnt = 0
        # patch-token statistics
        self.patch_unique_cnt_counter = Counter()
        # text align counts
        self.start_time_align_count = 0
        self.end_time_align_count = 0
        self.start_end_count = 0
        # 1 - CER, accumulated over queries
        self.one_minus_cer = 0
        self.token_count = 0
        # 1 - CER broken down per language
        self.one_minus_cer_lang = defaultdict(int)
        self.query_count_lang = defaultdict(int)
        # sentence-cut statistics
        self.miss_count = 0
        self.more_count = 0
        self.sentence_count = 0
        self.cut_rate = 0
        # last evaluated context, kept for detail-case reporting
        self.context = ASRContext()
        # latency samples (one entry per query with timing data)
        self.send_interval = []
        self.last_recv_interval = []
        # segments failing the char-contains-rate requirement
        self.fail_char_contains_rate_num = 0
        # punctuation counts: ground truth vs. prediction
        self.punctuation_num = 0
        self.pred_punctuation_num = 0

    def evaluate(self, context: ASRContext):
        """Accumulate all metrics for a single query's ``context``."""
        self.query_count += 1
        self.query_count_lang[context.lang] += 1

        voice_count = len(context.labels)
        self.voice_count += voice_count

        self.punctuation_num += context.punctuation_num
        self.pred_punctuation_num += context.pred_punctuation_num

        if not context.fail:
            # first-token delay
            first_delay_sum, first_delay_cnt = first_delay(context)
            self.first_delay_sum += first_delay_sum
            self.first_delay_cnt += first_delay_cnt

            # revision delay
            revision_delay_sum, revision_delay_cnt = revision_delay(context)
            self.revision_delay_sum += revision_delay_sum
            self.revision_delay_cnt += revision_delay_cnt

            # patch-token statistics
            counter = patch_unique_token_count(context)
            self.patch_unique_cnt_counter += counter
        else:
            self.fail_count += 1

        self.fail_char_contains_rate_num += context.fail_char_contains_rate_num

        # text align counts
        start_time_align_count, end_time_align_count, start_end_count = text_align(context)
        self.start_time_align_count += start_time_align_count
        self.end_time_align_count += end_time_align_count
        self.start_end_count += start_end_count

        # cer, wer — ground truth vs. final recognition results only
        sentences_gt: List[str] = [item.answer for item in context.labels]
        sentences_dt: List[str] = [
            item.recognition_results.text for item in context.preds if item.recognition_results.final_result
        ]
        if IN_TEST:
            print(sentences_gt)
            print(sentences_dt)

        sentences_gt: List[str] = cut_sentence(sentences_gt, TOKENIZER_MAPPING.get(context.lang))
        sentences_dt: List[str] = cut_sentence(sentences_dt, TOKENIZER_MAPPING.get(context.lang))
        if IN_TEST:
            print(sentences_gt)
            print(sentences_dt)

        # norm & tokenize
        tokens_gt: List[List[str]] = Tokenizer.norm_and_tokenize(sentences_gt, context.lang)
        tokens_dt: List[List[str]] = Tokenizer.norm_and_tokenize(sentences_dt, context.lang)
        if IN_TEST:
            print(tokens_gt)
            print(tokens_dt)

        # cer (fixed local-variable typo: was `one_minue_cer`)
        tokens_gt_mapping, tokens_dt_mapping = token_mapping(list(chain(*tokens_gt)), list(chain(*tokens_dt)))
        one_minus_cer, token_count = cer(tokens_gt_mapping, tokens_dt_mapping)
        self.one_minus_cer += one_minus_cer
        self.token_count += token_count
        self.one_minus_cer_lang[context.lang] += one_minus_cer

        # cut-rate
        rate, sentence_cnt, miss_cnt, more_cnt = cut_rate(tokens_gt, tokens_dt, tokens_gt_mapping, tokens_dt_mapping)
        self.cut_rate += rate
        self.sentence_count += sentence_cnt
        self.miss_count += miss_cnt
        self.more_count += more_cnt

        # detail-case
        self.context = context

        # latency
        if self.context.send_time_start_end and self.context.recv_time_start_end:
            send_interval = self.context.send_time_start_end[1] - self.context.send_time_start_end[0]
            # NOTE(review): measured from the FIRST send to the last recv
            # (end-to-end span), not recv[1] - recv[0] — confirm intended.
            recv_interval = self.context.recv_time_start_end[1] - self.context.send_time_start_end[0]
            self.send_interval.append(send_interval)
            self.last_recv_interval.append(recv_interval)
            logger.info(
                f"""第一次发送时间{self.context.send_time_start_end[0]}, \
最后一次发送时间{self.context.send_time_start_end[-1]}, \
发送间隔 {send_interval},
最后一次接收时间{self.context.recv_time_start_end[-1]}, \
接收间隔 {recv_interval}
"""
            )

    def post_evaluate(self):
        """Hook for subclasses; called after all queries are evaluated."""
        pass

    def gen_result(self):
        """Build the aggregated metrics dict.

        Side effect: calls ``change_product_available()`` when the first-delay
        mean exceeds the ``FIRST_DELAY_THRESHOLD`` env var (default 5) or more
        than 10% of segments fail the char-contains-rate requirement.

        All ratios are guarded against zero denominators so an empty or fully
        failed run reports 0.0 instead of raising ZeroDivisionError.
        """

        def _ratio(num, den):
            # Zero-denominator guard for empty/failed runs.
            return num / den if den else 0.0

        result = {
            "query_count": self.query_count,
            "voice_count": self.voice_count,
            "pred_voice_count": self.first_delay_cnt,
            # Delays default to 10 (above the 5s threshold) when no sample was
            # collected, so a dead run is flagged as unavailable below.
            "first_delay_mean": self.first_delay_sum / self.first_delay_cnt if self.first_delay_cnt > 0 else 10,
            "revision_delay_mean": (
                self.revision_delay_sum / self.revision_delay_cnt if self.revision_delay_cnt > 0 else 10
            ),
            "patch_token_mean": mean_on_counter(self.patch_unique_cnt_counter),
            "patch_token_var": var_on_counter(self.patch_unique_cnt_counter),
            "start_time_align_count": self.start_time_align_count,
            "end_time_align_count": self.end_time_align_count,
            "start_time_align_rate": _ratio(self.start_time_align_count, self.sentence_count),
            "end_time_align_rate": _ratio(self.end_time_align_count, self.sentence_count),
            "start_end_count": self.start_end_count,
            "one_minus_cer": _ratio(self.one_minus_cer, self.query_count),
            "token_count": self.token_count,
            "miss_count": self.miss_count,
            "more_count": self.more_count,
            "sentence_count": self.sentence_count,
            "cut_rate": _ratio(self.cut_rate, self.query_count),
            "fail_count": self.fail_count,
            "send_interval": self.send_interval,
            "last_recv_interval": self.last_recv_interval,
            "fail_char_contains_rate_num": self.fail_char_contains_rate_num,
            "punctuation_rate": _ratio(self.pred_punctuation_num, self.punctuation_num),
        }
        for lang in self.one_minus_cer_lang:
            result["one_minus_cer_" + lang] = _ratio(
                self.one_minus_cer_lang[lang], self.query_count_lang[lang]
            )

        if (
            result["first_delay_mean"] > float(os.getenv("FIRST_DELAY_THRESHOLD", "5"))
            or _ratio(self.fail_char_contains_rate_num, self.voice_count) > 0.1
            # or
            # result["punctuation_rate"] < 0.8
        ):
            change_product_available()
        return result

    def gen_detail_case(self):
        """Return the most recently evaluated context for detail reporting."""
        return self.context
|
||||
Reference in New Issue
Block a user