# coding: utf-8
import os
from collections import Counter, defaultdict
from itertools import chain
from typing import List

from schemas.context import ASRContext
from utils.logger import logger
from utils.metrics import cer, cut_rate, cut_sentence, first_delay
from utils.metrics import mean_on_counter, patch_unique_token_count
from utils.metrics import revision_delay, text_align, token_mapping
from utils.metrics import var_on_counter
from utils.tokenizer import TOKENIZER_MAPPING, Tokenizer
from utils.update_submit import change_product_available

# NOTE(review): with the non-None default `1`, os.getenv can never return None,
# so IN_TEST is always False and every debug print below is dead code. If the
# intent is "debug when SUBMIT_CONFIG_FILEPATH is unset", this should be
# `os.getenv("SUBMIT_CONFIG_FILEPATH") is None` — TODO confirm whether the
# default was added deliberately as a kill switch.
IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", 1) is None


class BaseEvaluator:
    """Accumulate streaming-ASR evaluation statistics across queries.

    One ``evaluate(context)`` call folds a single query (one audio clip's
    ground truth + predictions) into running counters; ``gen_result()``
    renders the aggregate metrics as a flat dict and trips the product
    availability switch when quality thresholds are violated.
    """

    def __init__(self) -> None:
        self.query_count = 0  # number of queries (audio clips)
        self.voice_count = 0  # total number of labeled voice segments
        self.fail_count = 0  # number of failed queries
        # first-token delay accumulators
        self.first_delay_sum = 0
        self.first_delay_cnt = 0
        # revision delay accumulators
        self.revision_delay_sum = 0
        self.revision_delay_cnt = 0
        # patch unique-token histogram (value -> occurrence count)
        self.patch_unique_cnt_counter = Counter()
        # text-align counts
        self.start_time_align_count = 0
        self.end_time_align_count = 0
        self.start_end_count = 0
        # 1 - CER accumulators (summed per query, averaged in gen_result)
        self.one_minus_cer = 0
        self.token_count = 0
        # 1 - CER per language
        self.one_minus_cer_lang = defaultdict(int)
        self.query_count_lang = defaultdict(int)
        # sentence-cut statistics
        self.miss_count = 0
        self.more_count = 0
        self.sentence_count = 0
        self.cut_rate = 0
        # most recent context, kept for detail-case reporting
        self.context = ASRContext()
        # latency samples, one entry per query that carried timing data
        self.send_interval = []
        self.last_recv_interval = []
        # count of segments failing the char-contains-rate check
        self.fail_char_contains_rate_num = 0
        # punctuation counts (ground truth vs. predicted)
        self.punctuation_num = 0
        self.pred_punctuation_num = 0

    def evaluate(self, context: ASRContext):
        """Fold one query's ``context`` into the accumulated statistics."""
        self.query_count += 1
        self.query_count_lang[context.lang] += 1
        self.voice_count += len(context.labels)
        self.punctuation_num += context.punctuation_num
        self.pred_punctuation_num += context.pred_punctuation_num

        if not context.fail:
            # first-token delay
            first_delay_sum, first_delay_cnt = first_delay(context)
            self.first_delay_sum += first_delay_sum
            self.first_delay_cnt += first_delay_cnt
            # revision delay
            revision_delay_sum, revision_delay_cnt = revision_delay(context)
            self.revision_delay_sum += revision_delay_sum
            self.revision_delay_cnt += revision_delay_cnt
            # patch unique-token statistics
            self.patch_unique_cnt_counter += patch_unique_token_count(context)
        else:
            self.fail_count += 1
        # NOTE(review): accumulated for every query; if this count is only
        # ever nonzero on failed queries the result is identical — confirm
        # whether it was meant to live inside the else branch above.
        self.fail_char_contains_rate_num += context.fail_char_contains_rate_num

        # text-align counts
        start_time_align_count, end_time_align_count, start_end_count = text_align(context)
        self.start_time_align_count += start_time_align_count
        self.end_time_align_count += end_time_align_count
        self.start_end_count += start_end_count

        # CER / WER inputs: ground-truth answers vs. final recognition texts
        sentences_gt: List[str] = [item.answer for item in context.labels]
        sentences_dt: List[str] = [
            item.recognition_results.text
            for item in context.preds
            if item.recognition_results.final_result
        ]
        if IN_TEST:
            print(sentences_gt)
            print(sentences_dt)
        sentences_gt = cut_sentence(sentences_gt, TOKENIZER_MAPPING.get(context.lang))
        sentences_dt = cut_sentence(sentences_dt, TOKENIZER_MAPPING.get(context.lang))
        if IN_TEST:
            print(sentences_gt)
            print(sentences_dt)

        # normalize & tokenize
        tokens_gt: List[List[str]] = Tokenizer.norm_and_tokenize(sentences_gt, context.lang)
        tokens_dt: List[List[str]] = Tokenizer.norm_and_tokenize(sentences_dt, context.lang)
        if IN_TEST:
            print(tokens_gt)
            print(tokens_dt)

        # CER over the flattened token streams
        tokens_gt_mapping, tokens_dt_mapping = token_mapping(
            list(chain(*tokens_gt)), list(chain(*tokens_dt))
        )
        one_minus_cer_val, token_count = cer(tokens_gt_mapping, tokens_dt_mapping)
        self.one_minus_cer += one_minus_cer_val
        self.token_count += token_count
        self.one_minus_cer_lang[context.lang] += one_minus_cer_val

        # sentence cut rate
        rate, sentence_cnt, miss_cnt, more_cnt = cut_rate(
            tokens_gt, tokens_dt, tokens_gt_mapping, tokens_dt_mapping
        )
        self.cut_rate += rate
        self.sentence_count += sentence_cnt
        self.miss_count += miss_cnt
        self.more_count += more_cnt

        # keep the latest context for gen_detail_case()
        self.context = context

        # latency
        if self.context.send_time_start_end and self.context.recv_time_start_end:
            send_interval = (
                self.context.send_time_start_end[1] - self.context.send_time_start_end[0]
            )
            # NOTE(review): measured from the first *send* time, i.e. this is
            # end-to-end latency rather than recv[1] - recv[0] — confirm intended.
            recv_interval = (
                self.context.recv_time_start_end[1] - self.context.send_time_start_end[0]
            )
            self.send_interval.append(send_interval)
            self.last_recv_interval.append(recv_interval)
            logger.info(
                f"""第一次发送时间{self.context.send_time_start_end[0]}, \
最后一次发送时间{self.context.send_time_start_end[-1]}, \
发送间隔 {send_interval}, 最后一次接收时间{self.context.recv_time_start_end[-1]}, \
接收间隔 {recv_interval} """
            )

    def post_evaluate(self):
        """Hook for subclasses; the base evaluator has no post-processing."""
        pass

    def gen_result(self):
        """Render the accumulated statistics as a flat result dict.

        Also flips the product availability switch when the first-delay
        threshold (env ``FIRST_DELAY_THRESHOLD``, default 5) or the
        char-contains-rate failure threshold (10%) is exceeded.

        All ratio metrics are guarded against zero denominators (an empty
        run previously raised ZeroDivisionError) and fall back to 0;
        the delay means fall back to the sentinel value 10.
        """
        result = {
            "query_count": self.query_count,
            "voice_count": self.voice_count,
            "pred_voice_count": self.first_delay_cnt,
            # 10 is the sentinel "very bad" delay used when nothing was measured
            "first_delay_mean": (
                self.first_delay_sum / self.first_delay_cnt
                if self.first_delay_cnt > 0
                else 10
            ),
            "revision_delay_mean": (
                self.revision_delay_sum / self.revision_delay_cnt
                if self.revision_delay_cnt > 0
                else 10
            ),
            "patch_token_mean": mean_on_counter(self.patch_unique_cnt_counter),
            "patch_token_var": var_on_counter(self.patch_unique_cnt_counter),
            "start_time_align_count": self.start_time_align_count,
            "end_time_align_count": self.end_time_align_count,
            "start_time_align_rate": (
                self.start_time_align_count / self.sentence_count
                if self.sentence_count
                else 0
            ),
            "end_time_align_rate": (
                self.end_time_align_count / self.sentence_count
                if self.sentence_count
                else 0
            ),
            "start_end_count": self.start_end_count,
            "one_minus_cer": (
                self.one_minus_cer / self.query_count if self.query_count else 0
            ),
            "token_count": self.token_count,
            "miss_count": self.miss_count,
            "more_count": self.more_count,
            "sentence_count": self.sentence_count,
            "cut_rate": self.cut_rate / self.query_count if self.query_count else 0,
            "fail_count": self.fail_count,
            "send_interval": self.send_interval,
            "last_recv_interval": self.last_recv_interval,
            "fail_char_contains_rate_num": self.fail_char_contains_rate_num,
            "punctuation_rate": (
                self.pred_punctuation_num / self.punctuation_num
                if self.punctuation_num
                else 0
            ),
        }
        # per-language 1-CER; query_count_lang[lang] is always >= 1 for any
        # lang present in one_minus_cer_lang (both are written together above)
        for lang in self.one_minus_cer_lang:
            result["one_minus_cer_" + lang] = (
                self.one_minus_cer_lang[lang] / self.query_count_lang[lang]
            )

        fail_char_rate = (
            self.fail_char_contains_rate_num / self.voice_count
            if self.voice_count
            else 0
        )
        if (
            result["first_delay_mean"] > float(os.getenv("FIRST_DELAY_THRESHOLD", "5"))
            or fail_char_rate > 0.1
            # or
            # result["punctuation_rate"] < 0.8
        ):
            change_product_available()
        return result

    def gen_detail_case(self):
        """Return the most recently evaluated context for detail inspection."""
        return self.context