321 lines
13 KiB
Python
321 lines
13 KiB
Python
# coding: utf-8
|
||
|
||
import os
|
||
from collections import Counter
|
||
from copy import deepcopy
|
||
from typing import List, Tuple
|
||
|
||
import Levenshtein
|
||
import numpy as np
|
||
from schemas.context import ASRContext
|
||
from utils.logger import logger
|
||
from utils.tokenizer import Tokenizer, TokenizerType
|
||
from utils.update_submit import change_product_available
|
||
|
||
# We are in a test environment exactly when no submit config file is configured.
IN_TEST = "SUBMIT_CONFIG_FILEPATH" not in os.environ
|
||
|
||
|
||
def text_align(context: ASRContext) -> Tuple:
    """Compare predicted sentence boundaries against the labelled ones.

    A label start/end time counts as aligned when some predicted start/end
    time falls within ±300 ms of it.  Predicted boundaries are additionally
    checked for monotonic ordering (start0 < end0 < start1 < end1 ...) with
    a 600 ms grace period; each violation is logged and counted.

    Returns:
        (start_time_align_count, end_time_align_count, start_end_count)
    """
    start_end_count = 0

    # Labelled sentence boundaries.
    label_start_time_list = [label_item.start for label_item in context.labels]
    label_end_time_list = [label_item.end for label_item in context.labels]

    # A prediction opens a new sentence right after the previous final result;
    # a final result closes the current sentence.
    pred_start_time_list = []
    pred_end_time_list = []
    sentence_start = True
    for pred_item in context.preds:
        results = pred_item.recognition_results
        if sentence_start:
            pred_start_time_list.append(results.start_time)
        if results.final_result:
            pred_end_time_list.append(results.end_time)
        sentence_start = results.final_result

    # check start0 < end0 < start1 < end1 < start2 < end2 - ...
    if IN_TEST:
        print(pred_start_time_list)
        print(pred_end_time_list)

    # Interleave starts and ends; a trailing unmatched start is kept.
    pred_time_list = []
    for start_time, end_time in zip(pred_start_time_list, pred_end_time_list):
        pred_time_list.append(start_time)
        pred_time_list.append(end_time)
    if len(pred_start_time_list) > len(pred_end_time_list):
        pred_time_list.append(pred_start_time_list[-1])

    for idx in range(1, len(pred_time_list)):
        # Allow a 600 ms grace period before flagging an ordering violation.
        if pred_time_list[idx] < pred_time_list[idx - 1] - 0.6:
            logger.error("识别的 start、end 不符合 start0 < end0 < start1 < end1 < start2 < end2 ...")
            logger.error(
                f"当前识别的每个句子开始和结束时间分别为: \
                开始时间:{pred_start_time_list}, \
                结束时间:{pred_end_time_list}"
            )
            start_end_count += 1
            # change_product_available()

    # A labelled boundary is aligned when a prediction lies within ±300 ms.
    start_time_align_count = 0
    end_time_align_count = 0
    for label_start_time in label_start_time_list:
        if any(
            label_start_time - 0.3 <= pred_start_time <= label_start_time + 0.3
            for pred_start_time in pred_start_time_list
        ):
            start_time_align_count += 1
    for label_end_time in label_end_time_list:
        if any(
            label_end_time - 0.3 <= pred_end_time <= label_end_time + 0.3
            for pred_end_time in pred_end_time_list
        ):
            end_time_align_count += 1
    logger.info(
        f"start-time 对齐个数 {start_time_align_count}, \
        end-time 对齐个数 {end_time_align_count}\
        数据集中句子总数 {len(label_start_time_list)}"
    )
    return start_time_align_count, end_time_align_count, start_end_count
|
||
|
||
|
||
def first_delay(context: ASRContext) -> Tuple:
    """Compute the first-token delay of every sentence in this audio.

    The delay of a sentence is the wall-clock offset of its first
    recognition packet (recv_time relative to the very first send_time)
    minus the sentence's start position in the audio.

    Returns:
        (sum of per-sentence first delays, number of sentences).
    """
    base_send_time = context.preds[0].send_time
    delays = []
    at_sentence_start = True
    for pred in context.preds:
        results = pred.recognition_results
        if at_sentence_start:
            # Receive offset minus the audio position of the sentence start.
            delays.append(pred.recv_time - base_send_time - results.start_time)
        # The packet after a final result opens the next sentence.
        at_sentence_start = results.final_result
    if IN_TEST:
        print(f"当前音频的首字延迟为{delays}")
    logger.info(f"当前音频的首字延迟均值为 {np.mean(delays)}s")
    return np.sum(delays), len(delays)
|
||
|
||
|
||
def revision_delay(context: ASRContext):
    """Compute the revision (finalization) delay of every sentence.

    For each final result, the delay is its receive offset (recv_time
    relative to the first send_time) minus the sentence's end position
    in the audio.

    Returns:
        (sum of per-sentence revision delays, number of sentences).
    """
    base_send_time = context.preds[0].send_time
    delays = [
        pred.recv_time - base_send_time - pred.recognition_results.end_time
        for pred in context.preds
        if pred.recognition_results.final_result
    ]
    if IN_TEST:
        print(delays)
    logger.info(f"当前音频的修正延迟均值为 {np.mean(delays)}s")
    return np.sum(delays), len(delays)
|
||
|
||
|
||
def patch_unique_token_count(context: ASRContext) -> Counter:
    """Collect per-patch token-count statistics and check token stability.

    Tokenizes every partial/final recognition result, then:
      1. Verifies that tokens older than 3 seconds (relative to the packet
         receive times) were never modified by a later patch of the same
         sentence; a violation is logged as an error.
      2. Builds a Counter mapping each distinct token count observed within
         a sentence to the number of sentences it occurred in, and logs its
         mean and variance.

    Returns:
        Counter of {token_count: number_of_sentences_containing_it}.
    """
    # print(context.__dict__)
    # Tokenize every returned recognition result.
    pred_text_list = [pred_context.recognition_results.text for pred_context in context.preds]
    pred_text_tokenized_list = Tokenizer.norm_and_tokenize(pred_text_list, lang=context.lang)
    # print(pred_text_list)
    # print(pred_text_tokenized_list)

    # Check whether any token older than 3 s was modified by a later patch.
    # Receive time of the first packet of the current sentence.
    first_recv_time = None
    # Number of leading tokens that may no longer be modified.
    unmodified_token_cnt = 0
    # Index of the oldest packet still inside the 3 s window.
    time_token_idx = 0
    # True when the next packet starts a new sentence.
    final_sentence = True

    # Set when the frozen (unmodifiable) token range was changed.
    is_unmodified_token = False

    for idx, (now_tokens, pred_context) in enumerate(zip(pred_text_tokenized_list, context.preds)):
        # First packet of a sentence: reset the per-sentence window state.
        if final_sentence:
            first_recv_time = pred_context.recv_time
            unmodified_token_cnt = 0
            time_token_idx = idx
            final_sentence = pred_context.recognition_results.final_result
            continue
        final_sentence = pred_context.recognition_results.final_result
        # Receive time of the current packet.
        pred_recv_time = pred_context.recv_time
        # The first 3 s of a sentence are exempt from the check.
        if pred_recv_time - first_recv_time < 3:
            continue
        # Advance over packets received more than 3 s ago; their token
        # counts define the longest frozen prefix.
        while time_token_idx < idx:
            context_pred_tmp = context.preds[time_token_idx]
            context_pred_tmp_recv_time = context_pred_tmp.recv_time
            tmp_tokens = pred_text_tokenized_list[time_token_idx]
            if pred_recv_time - context_pred_tmp_recv_time >= 3:
                unmodified_token_cnt = max(unmodified_token_cnt, len(tmp_tokens))
                time_token_idx += 1
            else:
                break
        # Compare against the previous packet: only tokens beyond the first
        # unmodified_token_cnt may change.
        last_tokens = pred_text_tokenized_list[idx - 1]
        if context.lang in ['ar', 'he']:
            # NOTE(review): this `continue` makes the comparison below
            # unreachable for RTL languages, so the reversal on this line is
            # dead code — the stability check is effectively skipped for
            # ar/he. Confirm whether that is intentional.
            tokens_check_pre, tokens_check_now = last_tokens[::-1], now_tokens[::-1]
            continue
        else:
            tokens_check_pre, tokens_check_now = last_tokens, now_tokens
        for token_a, token_b in zip(tokens_check_pre[:unmodified_token_cnt], tokens_check_now[:unmodified_token_cnt]):
            if token_a != token_b:
                is_unmodified_token = True
                break

        # Extra diagnostics when the 'test' env var is truthy.
        if is_unmodified_token and int(os.getenv('test', 0)):
            logger.error(
                f"{idx}-{unmodified_token_cnt}-{last_tokens[:unmodified_token_cnt]}-{now_tokens[:unmodified_token_cnt]}"
            )
        if is_unmodified_token:
            break

    if is_unmodified_token:
        logger.error("修改了不可修改的文字范围")
        # change_product_available()
    # In test mode, dump every sentence's (tokens, relative recv time) trace.
    if int(os.getenv('test', 0)):
        final_result = True
        result_list = []
        for tokens, pred in zip(pred_text_tokenized_list, context.preds):
            if final_result:
                result_list.append([])
            result_list[-1].append((tokens, pred.recv_time - context.preds[0].recv_time))
            final_result = pred.recognition_results.final_result
        for item in result_list:
            logger.info(str(item))

    # Record the distinct token counts seen per sentence.
    patch_unique_cnt_counter = Counter()
    patch_unique_cnt_in_one_sentence = set()
    for pred_text_tokenized, pred_context in zip(pred_text_tokenized_list, context.preds):
        token_cnt = len(pred_text_tokenized)
        patch_unique_cnt_in_one_sentence.add(token_cnt)
        if pred_context.recognition_results.final_result:
            # Sentence finished: fold its distinct counts into the Counter.
            for unique_cnt in patch_unique_cnt_in_one_sentence:
                patch_unique_cnt_counter[unique_cnt] += 1
            patch_unique_cnt_in_one_sentence.clear()
    # Flush a trailing sentence that never received a final result.
    if context.preds and not context.preds[-1].recognition_results.final_result:
        for unique_cnt in patch_unique_cnt_in_one_sentence:
            patch_unique_cnt_counter[unique_cnt] += 1
    # print(patch_unique_cnt_counter)
    logger.info(
        f"当前音频的 patch token 均值为 {mean_on_counter(patch_unique_cnt_counter)}, \
        当前音频的 patch token 方差为 {var_on_counter(patch_unique_cnt_counter)}"
    )
    return patch_unique_cnt_counter
|
||
|
||
|
||
def mean_on_counter(counter: Counter) -> float:
    """Weighted mean of a value->count histogram.

    Args:
        counter: maps each observed numeric value to its occurrence count.

    Returns:
        The mean as a float; 0.0 for an empty counter (instead of raising
        ZeroDivisionError).
    """
    total_count = sum(counter.values())
    # Guard: an empty histogram has no well-defined mean.
    if total_count == 0:
        return 0.0
    total_sum = sum(key * count for key, count in counter.items())
    return total_sum / total_count
|
||
|
||
|
||
def var_on_counter(counter: Counter) -> float:
    """Population variance of a value->count histogram.

    Args:
        counter: maps each observed numeric value to its occurrence count.

    Returns:
        The variance as a float; 0.0 for an empty counter (instead of
        raising ZeroDivisionError).
    """
    total_count = sum(counter.values())
    # Guard: an empty histogram has no well-defined variance.
    if total_count == 0:
        return 0.0
    total_sum = sum(key * count for key, count in counter.items())
    mean = total_sum / total_count
    return sum((key - mean) ** 2 * count for key, count in counter.items()) / total_count
|
||
|
||
|
||
def edit_distance(arr1: List, arr2: List):
    """Summarize the edit operations turning arr1 into arr2.

    Returns:
        (substitutions, deletions, insertions, correct) — correct is the
        number of arr1 tokens left untouched.
    """
    op_counts = Counter(operation[0] for operation in Levenshtein.editops(arr1, arr2))
    substitutions = op_counts["replace"]
    deletions = op_counts["delete"]
    insertions = op_counts["insert"]
    correct = len(arr1) - substitutions - deletions
    return substitutions, deletions, insertions, correct
|
||
|
||
|
||
def cer(tokens_gt_mapping: List[str], tokens_dt_mapping: List[str]):
    """Compute accuracy from two edit-distance-aligned token sequences.

    Both inputs come from token_mapping(): positions inserted relative to
    the ground truth are None in tokens_gt_mapping, deleted positions are
    None in tokens_dt_mapping, and the two lists have equal length.

    Args:
        tokens_gt_mapping: aligned ground-truth tokens (None = insertion).
        tokens_dt_mapping: aligned hypothesis tokens (None = deletion).

    Returns:
        (1 - cer, token_count) where token_count = replace + equal + delete.
        Returns (1.0, 0) when there are no reference tokens instead of
        raising ZeroDivisionError.
    """
    insert = sum(1 for item in tokens_gt_mapping if item is None)
    delete = sum(1 for item in tokens_dt_mapping if item is None)
    equal = sum(1 for token_gt, token_dt in zip(tokens_gt_mapping, tokens_dt_mapping) if token_gt == token_dt)
    replace = len(tokens_gt_mapping) - insert - equal

    token_count = replace + equal + delete
    # Guard: with zero reference tokens the rate is undefined; report a
    # perfect score instead of dividing by zero.
    if token_count == 0:
        return 1.0, 0
    cer_value = (replace + delete + insert) * 1.0 / token_count
    logger.info(f"当前音频的 cer/wer 值为 {cer_value}, token 个数为 {token_count}")
    return 1 - cer_value, token_count
|
||
|
||
|
||
def cut_rate(
    tokens_gt: List[List[str]],
    tokens_dt: List[List[str]],
    tokens_gt_mapping: List[str],
    tokens_dt_mapping: List[str],
):
    """Score how well predicted sentence cuts match the ground-truth cuts.

    A cut is identified by the mapped index of a sentence's final token.
    Missing cuts cost 1 each, extra cuts cost 2 each, and the score is
    clamped at 0.

    Returns:
        (rate, sentence_count_gt, miss_count, more_count).
    """
    final_idx_gt = set(sentence_final_token_index(tokens_gt, tokens_gt_mapping))
    final_idx_dt = set(sentence_final_token_index(tokens_dt, tokens_dt_mapping))
    sentence_count_gt = len(final_idx_gt)
    miss_count = len(final_idx_gt - final_idx_dt)
    more_count = len(final_idx_dt - final_idx_gt)
    penalty = (miss_count + more_count * 2) / sentence_count_gt
    rate = max(1 - penalty, 0)
    return rate, sentence_count_gt, miss_count, more_count
|
||
|
||
|
||
def token_mapping(tokens_gt: List[str], tokens_dt: List[str]) -> Tuple[List[str], List[str]]:
    """Align two token sequences by padding edit positions with None.

    Insertions (tokens only in tokens_dt) become None in the ground-truth
    copy; deletions (tokens only in tokens_gt) become None in the
    hypothesis copy, so both returned lists have equal length.
    """
    aligned_gt = deepcopy(tokens_gt)
    aligned_dt = deepcopy(tokens_dt)
    # Apply the ops back-to-front so earlier indices remain valid.
    for op_name, gt_idx, dt_idx in reversed(Levenshtein.editops(aligned_gt, aligned_dt)):
        if op_name == "insert":
            aligned_gt.insert(gt_idx, None)
        elif op_name == "delete":
            aligned_dt.insert(dt_idx, None)
    return aligned_gt, aligned_dt
|
||
|
||
|
||
def sentence_final_token_index(tokens: List[List[str]], tokens_mapping: List[str]) -> List[int]:
    """Return, per sentence, the mapped index of its final token.

    Args:
        tokens: the token lists of the original sentences.
        tokens_mapping: the flattened token stream with None padding
            inserted by the alignment step; None entries are skipped when
            advancing through the mapping.
    """
    final_indices = []
    cursor = 0
    mapping_len = len(tokens_mapping)
    for sentence_tokens in tokens:
        for _ in sentence_tokens:
            # Skip alignment padding, then consume one real token slot.
            while cursor < mapping_len and tokens_mapping[cursor] is None:
                cursor += 1
            cursor += 1
        final_indices.append(cursor - 1)
    return final_indices
|
||
|
||
|
||
def cut_sentence(sentences: List[str], tokenizerType: TokenizerType) -> List[str]:
    """Strip sentence punctuation and re-join the remaining fragments.

    Every sentence is split on each punctuation mark in turn (longest
    marks first so "......" is consumed before "..." and "."), empty
    fragments are dropped, and the rest are joined with a space for
    whitespace tokenizers or with nothing otherwise.
    """
    punctuation_marks = (
        "······",
        "......",
        "。",
        ",",
        "?",
        "!",
        ";",
        ":",
        "...",
        ".",
        ",",
        "?",
        "!",
        ";",
        ":",
    )
    joiner = " " if tokenizerType == TokenizerType.whitespace else ""
    sentence_cut_list = []
    for raw_sentence in sentences:
        fragments = [raw_sentence]
        for punc in punctuation_marks:
            next_fragments = []
            for fragment in fragments:
                next_fragments.extend(fragment.split(punc))
            fragments = next_fragments
        sentence_cut_list.append(joiner.join(fragment for fragment in fragments if fragment))
    return sentence_cut_list
|