# coding: utf-8
import os
from collections import Counter
from copy import deepcopy
from typing import List, Tuple

import Levenshtein
import numpy as np

from schemas.context import ASRContext
from utils.logger import logger
from utils.tokenizer import Tokenizer, TokenizerType
from utils.update_submit import change_product_available

# No submit config in the environment means a local/test run, which enables
# the extra debug printing below.
IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", None) is None


def text_align(context: ASRContext) -> Tuple:
    """Check predicted sentence start/end times against the labels.

    Args:
        context: ASR context holding ``labels`` (ground-truth segments) and
            ``preds`` (streaming recognition results).

    Returns:
        ``(start_time_align_count, end_time_align_count, start_end_count)``:
        the number of label start/end times matched by some prediction within
        ±300 ms, and the number of timestamp-ordering violations detected.
    """
    start_end_count = 0

    label_start_time_list = [label_item.start for label_item in context.labels]
    label_end_time_list = [label_item.end for label_item in context.labels]

    # Collect the predicted start time of every sentence (the first result
    # after a final one) and the end time of every finished sentence.
    pred_start_time_list = []
    pred_end_time_list = []
    sentence_start = True
    for pred_item in context.preds:
        if sentence_start:
            pred_start_time_list.append(pred_item.recognition_results.start_time)
        if pred_item.recognition_results.final_result:
            pred_end_time_list.append(pred_item.recognition_results.end_time)
        # A final result closes the sentence; the next pred opens a new one.
        sentence_start = pred_item.recognition_results.final_result

    # Check start0 < end0 < start1 < end1 < start2 < end2 < ...
    if IN_TEST:
        print(pred_start_time_list)
        print(pred_end_time_list)
    pred_time_list = []
    i, j = 0, 0
    while i < len(pred_start_time_list) and j < len(pred_end_time_list):
        pred_time_list.append(pred_start_time_list[i])
        pred_time_list.append(pred_end_time_list[j])
        i += 1
        j += 1
    if i < len(pred_start_time_list):
        # The last sentence may still be open: it has a start but no end yet.
        pred_time_list.append(pred_start_time_list[-1])
    for i in range(1, len(pred_time_list)):
        # Allow a 600 ms grace period before flagging an ordering violation.
        if pred_time_list[i] < pred_time_list[i - 1] - 0.6:
            logger.error("识别的 start、end 不符合 start0 < end0 < start1 < end1 < start2 < end2 ...")
            # Implicit f-string concatenation instead of "\" continuations,
            # which injected long literal whitespace runs into the message.
            logger.error(
                f"当前识别的每个句子开始和结束时间分别为: "
                f"开始时间:{pred_start_time_list}, 结束时间:{pred_end_time_list}"
            )
            start_end_count += 1
            # change_product_available()

    # A label timestamp counts as aligned when some prediction falls within
    # ±300 ms of it.
    start_time_align_count = 0
    end_time_align_count = 0
    for label_start_time in label_start_time_list:
        for pred_start_time in pred_start_time_list:
            if label_start_time - 0.3 <= pred_start_time <= label_start_time + 0.3:
                start_time_align_count += 1
                break
    for label_end_time in label_end_time_list:
        for pred_end_time in pred_end_time_list:
            if label_end_time - 0.3 <= pred_end_time <= label_end_time + 0.3:
                end_time_align_count += 1
                break
    logger.info(
        f"start-time 对齐个数 {start_time_align_count}, "
        f"end-time 对齐个数 {end_time_align_count}, "
        f"数据集中句子总数 {len(label_start_time_list)}"
    )
    return start_time_align_count, end_time_align_count, start_end_count


def first_delay(context: ASRContext) -> Tuple:
    """Compute the first-token latency of every sentence.

    For the first result of each sentence the delay is
    ``recv_time - send_time_of_first_packet - sentence_start_time``.
    NOTE(review): this assumes ``start_time`` is relative to the start of the
    audio stream — confirm against the producer of ``recognition_results``.

    Returns:
        ``(sum_of_delays, sentence_count)`` so callers can aggregate a mean
        across multiple audios.
    """
    first_send_time = context.preds[0].send_time
    first_delay_list = []
    sentence_start = True
    for pred_context in context.preds:
        if sentence_start:
            sentence_begin_time = pred_context.recognition_results.start_time
            first_delay_time = pred_context.recv_time - first_send_time - sentence_begin_time
            first_delay_list.append(first_delay_time)
        # A final result closes the sentence; the next pred opens a new one.
        sentence_start = pred_context.recognition_results.final_result
    if IN_TEST:
        print(f"当前音频的首字延迟为{first_delay_list}")
    logger.info(f"当前音频的首字延迟均值为 {np.mean(first_delay_list)}s")
    return np.sum(first_delay_list), len(first_delay_list)
def revision_delay(context: ASRContext):
    """Compute the revision (final-result) latency of every sentence.

    For each final result the delay is
    ``recv_time - send_time_of_first_packet - sentence_end_time``.

    Returns:
        ``(sum_of_delays, sentence_count)`` so callers can aggregate a mean
        across multiple audios.
    """
    first_send_time = context.preds[0].send_time
    revision_delay_list = []
    for pred_context in context.preds:
        if pred_context.recognition_results.final_result:
            sentence_end_time = pred_context.recognition_results.end_time
            revision_delay_time = pred_context.recv_time - first_send_time - sentence_end_time
            revision_delay_list.append(revision_delay_time)
    if IN_TEST:
        print(revision_delay_list)
    logger.info(f"当前音频的修正延迟均值为 {np.mean(revision_delay_list)}s")
    return np.sum(revision_delay_list), len(revision_delay_list)


def patch_unique_token_count(context: ASRContext):
    """Validate the 3-second token freeze rule and count patch sizes.

    A token that was already present in a result received more than 3 s ago
    must not change in later results of the same sentence. Violations are
    logged. Additionally, the set of distinct token counts seen within each
    sentence is accumulated into a Counter.

    Returns:
        Counter mapping a per-patch token count to how many sentences
        produced it.
    """
    # Tokenize every returned partial/final result.
    pred_text_list = [pred_context.recognition_results.text for pred_context in context.preds]
    pred_text_tokenized_list = Tokenizer.norm_and_tokenize(pred_text_list, lang=context.lang)

    # --- Freeze-rule check -------------------------------------------------
    first_recv_time = None       # recv time of the first result of the current sentence
    unmodified_token_cnt = 0     # length of the frozen (no-longer-modifiable) prefix
    time_token_idx = 0           # index of the oldest result younger than 3 s
    final_sentence = True        # True when the current pred starts a new sentence
    is_unmodified_token = False  # set when a frozen token was modified
    for idx, (now_tokens, pred_context) in enumerate(zip(pred_text_tokenized_list, context.preds)):
        if final_sentence:
            # First result of a new sentence: reset per-sentence state.
            first_recv_time = pred_context.recv_time
            unmodified_token_cnt = 0
            time_token_idx = idx
            final_sentence = pred_context.recognition_results.final_result
            continue
        final_sentence = pred_context.recognition_results.final_result
        pred_recv_time = pred_context.recv_time
        # The first 3 s of a sentence are exempt from the freeze rule.
        if pred_recv_time - first_recv_time < 3:
            continue
        # Advance over history to find the longest prefix that must stay frozen:
        # the longest result received at least 3 s before this one.
        while time_token_idx < idx:
            context_pred_tmp = context.preds[time_token_idx]
            if pred_recv_time - context_pred_tmp.recv_time >= 3:
                unmodified_token_cnt = max(
                    unmodified_token_cnt, len(pred_text_tokenized_list[time_token_idx])
                )
                time_token_idx += 1
            else:
                break
        # Compare against the immediately preceding result: only tokens beyond
        # the frozen prefix may differ.
        last_tokens = pred_text_tokenized_list[idx - 1]
        if context.lang in ['ar', 'he']:
            # RTL languages: the freeze check is currently skipped entirely.
            # NOTE(review): the original code computed reversed token lists
            # here immediately before `continue` (a dead store), which suggests
            # the check was meant to run on reversed tokens — confirm intent
            # before enabling it.
            continue
        tokens_check_pre, tokens_check_now = last_tokens, now_tokens
        for token_a, token_b in zip(
            tokens_check_pre[:unmodified_token_cnt], tokens_check_now[:unmodified_token_cnt]
        ):
            if token_a != token_b:
                is_unmodified_token = True
                break
        if is_unmodified_token and int(os.getenv('test', 0)):
            logger.error(
                f"{idx}-{unmodified_token_cnt}-{last_tokens[:unmodified_token_cnt]}-{now_tokens[:unmodified_token_cnt]}"
            )
        if is_unmodified_token:
            break
    if is_unmodified_token:
        logger.error("修改了不可修改的文字范围")
        # change_product_available()

    # Optional debug dump: group tokenized results by sentence with their
    # relative receive times.
    if int(os.getenv('test', 0)):
        final_result = True
        result_list = []
        for tokens, pred in zip(pred_text_tokenized_list, context.preds):
            if final_result:
                result_list.append([])
            result_list[-1].append((tokens, pred.recv_time - context.preds[0].recv_time))
            final_result = pred.recognition_results.final_result
        for item in result_list:
            logger.info(str(item))

    # --- Patch-size statistics --------------------------------------------
    # Record, per sentence, the set of distinct token counts over its patches.
    patch_unique_cnt_counter = Counter()
    patch_unique_cnt_in_one_sentence = set()
    for pred_text_tokenized, pred_context in zip(pred_text_tokenized_list, context.preds):
        patch_unique_cnt_in_one_sentence.add(len(pred_text_tokenized))
        if pred_context.recognition_results.final_result:
            for unique_cnt in patch_unique_cnt_in_one_sentence:
                patch_unique_cnt_counter[unique_cnt] += 1
            patch_unique_cnt_in_one_sentence.clear()
    # Flush a trailing sentence that never received a final result.
    if context.preds and not context.preds[-1].recognition_results.final_result:
        for unique_cnt in patch_unique_cnt_in_one_sentence:
            patch_unique_cnt_counter[unique_cnt] += 1
    # Implicit f-string concatenation instead of "\" continuations, which
    # injected long literal whitespace runs into the message.
    logger.info(
        f"当前音频的 patch token 均值为 {mean_on_counter(patch_unique_cnt_counter)}, "
        f"当前音频的 patch token 方差为 {var_on_counter(patch_unique_cnt_counter)}"
    )
    return patch_unique_cnt_counter


def mean_on_counter(counter: Counter):
    """Mean of the counter's keys weighted by their counts.

    Raises ZeroDivisionError on an empty counter (callers always pass at
    least one entry).
    """
    total_sum = sum(key * count for key, count in counter.items())
    total_count = sum(counter.values())
    return total_sum * 1.0 / total_count


def var_on_counter(counter: Counter):
    """Population variance of the counter's keys weighted by their counts."""
    total_sum = sum(key * count for key, count in counter.items())
    total_count = sum(counter.values())
    mean = total_sum * 1.0 / total_count
    return sum((key - mean) ** 2 * count for key, count in counter.items()) / total_count


def edit_distance(arr1: List, arr2: List):
    """Count edit operations turning ``arr1`` into ``arr2``.

    Returns:
        ``(substitutions, deletions, insertions, correct)`` where
        ``correct = len(arr1) - substitutions - deletions``.
    """
    # One pass over the editops instead of three list-comprehension sums.
    op_counter = Counter(operation[0] for operation in Levenshtein.editops(arr1, arr2))
    s = op_counter["replace"]
    d = op_counter["delete"]
    i = op_counter["insert"]
    c = len(arr1) - s - d
    return s, d, i, c


def cer(tokens_gt_mapping: List[str], tokens_dt_mapping: List[str]):
    """Compute CER/WER from two edit-distance-aligned token sequences.

    The inputs come from :func:`token_mapping`: equal-length sequences where
    ``None`` in the GT side marks an insertion and ``None`` in the DT side
    marks a deletion.

    Returns:
        ``(1 - cer, token_count)`` where ``token_count`` is the number of
        ground-truth tokens.
    """
    insert = sum(1 for item in tokens_gt_mapping if item is None)
    delete = sum(1 for item in tokens_dt_mapping if item is None)
    equal = sum(1 for token_gt, token_dt in zip(tokens_gt_mapping, tokens_dt_mapping) if token_gt == token_dt)
    # BUGFIX: the original computed len - insert - equal, which counts every
    # deletion position as a substitution too; deletions were then double
    # counted in both token_count and the CER numerator.
    replace = len(tokens_gt_mapping) - insert - delete - equal
    token_count = replace + equal + delete
    cer_value = (replace + delete + insert) * 1.0 / token_count
    logger.info(f"当前音频的 cer/wer 值为 {cer_value}, token 个数为 {token_count}")
    return 1 - cer_value, token_count


def cut_rate(
    tokens_gt: List[List[str]],
    tokens_dt: List[List[str]],
    tokens_gt_mapping: List[str],
    tokens_dt_mapping: List[str],
):
    """Score sentence segmentation by comparing sentence-final token indices.

    Args:
        tokens_gt / tokens_dt: per-sentence token lists for GT and DT.
        tokens_gt_mapping / tokens_dt_mapping: the aligned flat sequences
            from :func:`token_mapping`.

    Returns:
        ``(rate, sentence_count_gt, miss_count, more_count)``.
    """
    sentence_final_token_index_gt = set(sentence_final_token_index(tokens_gt, tokens_gt_mapping))
    sentence_final_token_index_dt = set(sentence_final_token_index(tokens_dt, tokens_dt_mapping))
    sentence_count_gt = len(sentence_final_token_index_gt)
    miss_count = len(sentence_final_token_index_gt - sentence_final_token_index_dt)
    more_count = len(sentence_final_token_index_dt - sentence_final_token_index_gt)
    # Spurious cut points are penalized twice as much as missed ones; the
    # score is clamped at 0.
    rate = max(1 - (miss_count + more_count * 2) / sentence_count_gt, 0)
    return rate, sentence_count_gt, miss_count, more_count


def token_mapping(tokens_gt: List[str], tokens_dt: List[str]) -> Tuple[List[str], List[str]]:
    """Align two token sequences with ``None`` placeholders.

    After alignment both sequences have equal length: ``None`` in the GT side
    marks a token inserted by DT, ``None`` in the DT side marks a GT token
    that DT deleted.
    """
    arr1 = deepcopy(tokens_gt)
    arr2 = deepcopy(tokens_dt)
    operations = Levenshtein.editops(arr1, arr2)
    # Apply ops back to front so earlier op positions stay valid.
    for op in reversed(operations):
        if op[0] == "insert":
            arr1.insert(op[1], None)
        elif op[0] == "delete":
            arr2.insert(op[2], None)
    return arr1, arr2


def sentence_final_token_index(tokens: List[List[str]], tokens_mapping: List[str]) -> List[int]:
    """Return the mapped index of each sentence's final token.

    ``tokens`` is the per-sentence token lists; ``tokens_mapping`` is the
    aligned flat sequence (with ``None`` placeholders) covering the same
    tokens in order.
    """
    token_index_list = []
    token_index = 0
    for token_in_one_sentence in tokens:
        for _ in range(len(token_in_one_sentence)):
            # Skip placeholder positions, then consume one real token.
            while token_index < len(tokens_mapping) and tokens_mapping[token_index] is None:
                token_index += 1
            token_index += 1
        token_index_list.append(token_index - 1)
    return token_index_list


def cut_sentence(sentences: List[str], tokenizerType: TokenizerType) -> List[str]:
    """Strip punctuation from each sentence and re-join its fragments.

    Each sentence is split on every punctuation mark (longest marks first so
    "......" is consumed before "..." and "."), empty fragments are dropped,
    and the remainder is joined with a space for whitespace tokenizers or
    directly otherwise.
    """
    sentence_cut_list = []
    # Order matters: multi-char marks must be split before their prefixes.
    punctuations = [
        "······",
        "......",
        "。",
        ",",
        "?",
        "!",
        ";",
        ":",
        "...",
        ".",
        ",",
        "?",
        "!",
        ";",
        ":",
    ]
    joiner = " " if tokenizerType == TokenizerType.whitespace else ""
    for sentence in sentences:
        fragments = [sentence]
        for punc in punctuations:
            next_fragments = []
            # Renamed from the original's inner `sentence`, which shadowed the
            # outer loop variable.
            for fragment in fragments:
                next_fragments.extend(fragment.split(punc))
            fragments = next_fragments
        fragments = [item for item in fragments if item]
        sentence_cut_list.append(joiner.join(fragments))
    return sentence_cut_list