196 lines
7.7 KiB
Python
196 lines
7.7 KiB
Python
# coding: utf-8
|
|
|
|
import os
|
|
from collections import Counter, defaultdict
|
|
from itertools import chain
|
|
from typing import List
|
|
|
|
from schemas.context import ASRContext
|
|
from utils.logger import logger
|
|
from utils.metrics import cer, cut_rate, cut_sentence, first_delay
|
|
from utils.metrics import mean_on_counter, patch_unique_token_count
|
|
from utils.metrics import revision_delay, text_align, token_mapping
|
|
from utils.metrics import var_on_counter
|
|
from utils.tokenizer import TOKENIZER_MAPPING, Tokenizer
|
|
from utils.update_submit import change_product_available
|
|
|
|
# True when no submit-config path is configured, i.e. we are running locally /
# in test mode; enables the debug prints in BaseEvaluator.evaluate.
# Bug fix: the previous `os.getenv("SUBMIT_CONFIG_FILEPATH", 1) is None` could
# never be True, because getenv returns the default (1) instead of None when
# the variable is unset — so IN_TEST was always False and the debug branches
# were dead code.
IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH") is None
|
|
|
|
|
|
class BaseEvaluator:
    """Accumulates per-query ASR evaluation statistics and builds a summary.

    Usage: call :meth:`evaluate` once per query (one ``ASRContext`` each),
    optionally :meth:`post_evaluate`, then :meth:`gen_result` to obtain the
    aggregated metrics dict. :meth:`gen_detail_case` returns the last
    evaluated context for detailed inspection.
    """

    def __init__(self) -> None:
        self.query_count = 0  # number of queries (utterances) evaluated
        self.voice_count = 0  # total number of labelled voice segments
        self.fail_count = 0  # number of failed queries
        # first-token delay accumulators
        self.first_delay_sum = 0
        self.first_delay_cnt = 0
        # revision (correction) delay accumulators
        self.revision_delay_sum = 0
        self.revision_delay_cnt = 0
        # patch-token statistics (unique-token counts per patch)
        self.patch_unique_cnt_counter = Counter()
        # text-alignment counts
        self.start_time_align_count = 0
        self.end_time_align_count = 0
        self.start_end_count = 0
        # 1 - CER accumulators
        self.one_minus_cer = 0
        self.token_count = 0
        # 1 - CER accumulated per language
        self.one_minus_cer_lang = defaultdict(int)
        self.query_count_lang = defaultdict(int)
        # sentence-cut statistics
        self.miss_count = 0
        self.more_count = 0
        self.sentence_count = 0
        self.cut_rate = 0
        # last evaluated context, kept for detail-case reporting
        self.context = ASRContext()
        # latency intervals collected across queries
        self.send_interval = []
        self.last_recv_interval = []
        # number of segments failing the char-content-rate requirement
        self.fail_char_contains_rate_num = 0
        # punctuation counts (ground truth vs predicted)
        self.punctuation_num = 0
        self.pred_punctuation_num = 0

    def evaluate(self, context: ASRContext):
        """Accumulate all metrics for one query's context into the running totals."""
        self.query_count += 1
        self.query_count_lang[context.lang] += 1

        voice_count = len(context.labels)
        self.voice_count += voice_count

        self.punctuation_num += context.punctuation_num
        self.pred_punctuation_num += context.pred_punctuation_num

        if not context.fail:
            # first-token delay
            first_delay_sum, first_delay_cnt = first_delay(context)
            self.first_delay_sum += first_delay_sum
            self.first_delay_cnt += first_delay_cnt

            # revision delay
            revision_delay_sum, revision_delay_cnt = revision_delay(context)
            self.revision_delay_sum += revision_delay_sum
            self.revision_delay_cnt += revision_delay_cnt

            # patch-token statistics
            counter = patch_unique_token_count(context)
            self.patch_unique_cnt_counter += counter
        else:
            self.fail_count += 1

        self.fail_char_contains_rate_num += context.fail_char_contains_rate_num

        # text-alignment counts
        start_time_align_count, end_time_align_count, start_end_count = text_align(context)
        self.start_time_align_count += start_time_align_count
        self.end_time_align_count += end_time_align_count
        self.start_end_count += start_end_count

        # CER / WER: ground-truth answers vs final-result hypotheses
        sentences_gt: List[str] = [item.answer for item in context.labels]
        sentences_dt: List[str] = [
            item.recognition_results.text for item in context.preds if item.recognition_results.final_result
        ]
        if IN_TEST:
            print(sentences_gt)
            print(sentences_dt)

        sentences_gt: List[str] = cut_sentence(sentences_gt, TOKENIZER_MAPPING.get(context.lang))
        sentences_dt: List[str] = cut_sentence(sentences_dt, TOKENIZER_MAPPING.get(context.lang))
        if IN_TEST:
            print(sentences_gt)
            print(sentences_dt)

        # normalize & tokenize
        tokens_gt: List[List[str]] = Tokenizer.norm_and_tokenize(sentences_gt, context.lang)
        tokens_dt: List[List[str]] = Tokenizer.norm_and_tokenize(sentences_dt, context.lang)
        if IN_TEST:
            print(tokens_gt)
            print(tokens_dt)

        # CER over the flattened token streams
        tokens_gt_mapping, tokens_dt_mapping = token_mapping(list(chain(*tokens_gt)), list(chain(*tokens_dt)))
        one_minus_cer_val, token_count = cer(tokens_gt_mapping, tokens_dt_mapping)
        self.one_minus_cer += one_minus_cer_val
        self.token_count += token_count
        self.one_minus_cer_lang[context.lang] += one_minus_cer_val

        # cut-rate (sentence segmentation quality)
        rate, sentence_cnt, miss_cnt, more_cnt = cut_rate(tokens_gt, tokens_dt, tokens_gt_mapping, tokens_dt_mapping)
        self.cut_rate += rate
        self.sentence_count += sentence_cnt
        self.miss_count += miss_cnt
        self.more_count += more_cnt

        # keep the context for detail-case reporting
        self.context = context

        # latency logging
        if self.context.send_time_start_end and self.context.recv_time_start_end:
            send_interval = self.context.send_time_start_end[1] - self.context.send_time_start_end[0]
            # NOTE(review): measured from the first *send* timestamp, not the
            # first recv — presumably end-to-end latency; confirm intent.
            recv_interval = self.context.recv_time_start_end[1] - self.context.send_time_start_end[0]
            self.send_interval.append(send_interval)
            self.last_recv_interval.append(recv_interval)
            logger.info(
                f"""第一次发送时间{self.context.send_time_start_end[0]}, \
最后一次发送时间{self.context.send_time_start_end[-1]}, \
发送间隔 {send_interval},
最后一次接收时间{self.context.recv_time_start_end[-1]}, \
接收间隔 {recv_interval}
"""
            )

    def post_evaluate(self):
        """Hook for subclasses to finalize state after all queries; no-op here."""
        pass

    def gen_result(self):
        """Build the aggregated metrics dict and gate product availability.

        Bug fix: denominators that can legitimately be zero (no queries
        evaluated, no sentences counted, no ground-truth punctuation, no
        voice segments) are now guarded, so an empty evaluation run returns
        zeroed rates instead of raising ZeroDivisionError. The delay means
        keep their original fallback of 10 when no samples were counted.
        """
        result = {
            "query_count": self.query_count,
            "voice_count": self.voice_count,
            "pred_voice_count": self.first_delay_cnt,
            "first_delay_mean": self.first_delay_sum / self.first_delay_cnt if self.first_delay_cnt > 0 else 10,
            "revision_delay_mean": (
                self.revision_delay_sum / self.revision_delay_cnt if self.revision_delay_cnt > 0 else 10
            ),
            "patch_token_mean": mean_on_counter(self.patch_unique_cnt_counter),
            "patch_token_var": var_on_counter(self.patch_unique_cnt_counter),
            "start_time_align_count": self.start_time_align_count,
            "end_time_align_count": self.end_time_align_count,
            "start_time_align_rate": (
                self.start_time_align_count / self.sentence_count if self.sentence_count > 0 else 0
            ),
            "end_time_align_rate": (
                self.end_time_align_count / self.sentence_count if self.sentence_count > 0 else 0
            ),
            "start_end_count": self.start_end_count,
            "one_minus_cer": self.one_minus_cer / self.query_count if self.query_count > 0 else 0,
            "token_count": self.token_count,
            "miss_count": self.miss_count,
            "more_count": self.more_count,
            "sentence_count": self.sentence_count,
            "cut_rate": self.cut_rate / self.query_count if self.query_count > 0 else 0,
            "fail_count": self.fail_count,
            "send_interval": self.send_interval,
            "last_recv_interval": self.last_recv_interval,
            "fail_char_contains_rate_num": self.fail_char_contains_rate_num,
            "punctuation_rate": (
                self.pred_punctuation_num / self.punctuation_num if self.punctuation_num > 0 else 0
            ),
        }
        for lang in self.one_minus_cer_lang:
            result["one_minus_cer_" + lang] = \
                self.one_minus_cer_lang[lang] / self.query_count_lang[lang]

        # Mark the product unavailable when quality thresholds are violated.
        fail_char_rate = (
            self.fail_char_contains_rate_num / self.voice_count if self.voice_count > 0 else 0
        )
        if (
            result["first_delay_mean"]
            > float(os.getenv("FIRST_DELAY_THRESHOLD", "5"))
            or
            fail_char_rate > 0.1
            # or
            # result["punctuation_rate"] < 0.8
        ):
            change_product_available()
        return result

    def gen_detail_case(self):
        """Return the most recently evaluated context for detailed inspection."""
        return self.context
|