enginex-c_series-asr/utils/metrics.py
# coding: utf-8
import os
from collections import Counter
from copy import deepcopy
from typing import List, Tuple

import Levenshtein
import numpy as np

from schemas.context import ASRContext
from utils.logger import logger
from utils.tokenizer import Tokenizer, TokenizerType
from utils.update_submit import change_product_available

IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", None) is None
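
# Evaluation metrics for streaming ASR results: timestamp alignment against
# labels, first-token and revision delays, patch (partial-result) stability,
# CER/WER over edit-distance-aligned tokens, and sentence cut rate.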


def text_align(context: ASRContext) -> Tuple:
    start_end_count = 0
    label_start_time_list = []
    label_end_time_list = []
    for label_item in context.labels:
        label_start_time_list.append(label_item.start)
        label_end_time_list.append(label_item.end)

    pred_start_time_list = []
    pred_end_time_list = []
    sentence_start = True
    for pred_item in context.preds:
        if sentence_start:
            pred_start_time_list.append(pred_item.recognition_results.start_time)
        if pred_item.recognition_results.final_result:
            pred_end_time_list.append(pred_item.recognition_results.end_time)
        sentence_start = pred_item.recognition_results.final_result

    # check start0 < end0 < start1 < end1 < start2 < end2 ...
    if IN_TEST:
        print(pred_start_time_list)
        print(pred_end_time_list)
    pred_time_list = []
    i, j = 0, 0
    while i < len(pred_start_time_list) and j < len(pred_end_time_list):
        pred_time_list.append(pred_start_time_list[i])
        pred_time_list.append(pred_end_time_list[j])
        i += 1
        j += 1
    if i < len(pred_start_time_list):
        pred_time_list.append(pred_start_time_list[-1])
    for i in range(1, len(pred_time_list)):
        # allow a 600 ms grace period here
        if pred_time_list[i] < pred_time_list[i - 1] - 0.6:
            logger.error("recognized start/end times do not satisfy start0 < end0 < start1 < end1 < start2 < end2 ...")
            logger.error(
                f"start times of the recognized sentences: {pred_start_time_list}, "
                f"end times: {pred_end_time_list}"
            )
            start_end_count += 1
            # change_product_available()

    # a start/end time counts as aligned when it is within 300 ms of the label
    start_time_align_count = 0
    end_time_align_count = 0
    for label_start_time in label_start_time_list:
        for pred_start_time in pred_start_time_list:
            if label_start_time - 0.3 <= pred_start_time <= label_start_time + 0.3:
                start_time_align_count += 1
                break
    for label_end_time in label_end_time_list:
        for pred_end_time in pred_end_time_list:
            if label_end_time - 0.3 <= pred_end_time <= label_end_time + 0.3:
                end_time_align_count += 1
                break
    logger.info(
        f"aligned start-time count {start_time_align_count}, "
        f"aligned end-time count {end_time_align_count}, "
        f"total sentences in the dataset {len(label_start_time_list)}"
    )
    return start_time_align_count, end_time_align_count, start_end_count
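
# Worked example (hypothetical numbers): for a label sentence spanning
# 1.00 s - 2.50 s, a predicted start of 1.25 s falls inside the 300 ms window
# and counts as aligned, while a predicted start of 1.40 s does not; a
# predicted end of 2.60 s would still count as aligned.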


def first_delay(context: ASRContext) -> Tuple:
    first_send_time = context.preds[0].send_time
    first_delay_list = []
    sentence_start = True
    for pred_context in context.preds:
        if sentence_start:
            sentence_begin_time = pred_context.recognition_results.start_time
            first_delay_time = pred_context.recv_time - first_send_time - sentence_begin_time
            first_delay_list.append(first_delay_time)
        sentence_start = pred_context.recognition_results.final_result
    if IN_TEST:
        print(f"first-token delays of the current audio: {first_delay_list}")
    logger.info(f"mean first-token delay of the current audio: {np.mean(first_delay_list)}s")
    return np.sum(first_delay_list), len(first_delay_list)
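
# Worked example (hypothetical numbers): if the first packet is sent at t = 0,
# a sentence starts 1.2 s into the audio, and its first partial result is
# received at t = 1.5 s, the first-token delay is 1.5 - 0 - 1.2 = 0.3 s.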


def revision_delay(context: ASRContext):
    first_send_time = context.preds[0].send_time
    revision_delay_list = []
    for pred_context in context.preds:
        if pred_context.recognition_results.final_result:
            sentence_end_time = pred_context.recognition_results.end_time
            revision_delay_time = pred_context.recv_time - first_send_time - sentence_end_time
            revision_delay_list.append(revision_delay_time)
    if IN_TEST:
        print(revision_delay_list)
    logger.info(f"mean revision delay of the current audio: {np.mean(revision_delay_list)}s")
    return np.sum(revision_delay_list), len(revision_delay_list)
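
# Worked example (hypothetical numbers): a final result for a sentence that
# ends 4.0 s into the audio and is received at t = 4.6 s (relative to the
# first send time) has a revision delay of 4.6 - 4.0 = 0.6 s.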


def patch_unique_token_count(context: ASRContext):
    # print(context.__dict__)
    # tokenize every returned result
    pred_text_list = [pred_context.recognition_results.text for pred_context in context.preds]
    pred_text_tokenized_list = Tokenizer.norm_and_tokenize(pred_text_list, lang=context.lang)
    # print(pred_text_list)
    # print(pred_text_tokenized_list)

    # check whether any token older than 3 s was modified
    ## receive time of the first result of the current sentence
    first_recv_time = None
    ## number of tokens that must no longer be modified
    unmodified_token_cnt = 0
    ## index position of the 3 s boundary
    time_token_idx = 0
    ## whether we are at the start of a sentence
    final_sentence = True
    ## whether the unmodifiable range was modified
    is_unmodified_token = False
    for idx, (now_tokens, pred_context) in enumerate(zip(pred_text_tokenized_list, context.preds)):
        ## first result returned for the current sentence
        if final_sentence:
            first_recv_time = pred_context.recv_time
            unmodified_token_cnt = 0
            time_token_idx = idx
            final_sentence = pred_context.recognition_results.final_result
            continue
        final_sentence = pred_context.recognition_results.final_result
        ## recv-time of the current pred
        pred_recv_time = pred_context.recv_time
        ## ignore the first 3 s entirely
        if pred_recv_time - first_recv_time < 3:
            continue
        ## derive the longest unmodifiable prefix from earlier results
        while time_token_idx < idx:
            context_pred_tmp = context.preds[time_token_idx]
            context_pred_tmp_recv_time = context_pred_tmp.recv_time
            tmp_tokens = pred_text_tokenized_list[time_token_idx]
            if pred_recv_time - context_pred_tmp_recv_time >= 3:
                unmodified_token_cnt = max(unmodified_token_cnt, len(tmp_tokens))
                time_token_idx += 1
            else:
                break
        ## compared with the previous result, the first unmodified_token_cnt tokens must not change
        last_tokens = pred_text_tokenized_list[idx - 1]
        if context.lang in ['ar', 'he']:
            tokens_check_pre, tokens_check_now = last_tokens[::-1], now_tokens[::-1]
            continue
        else:
            tokens_check_pre, tokens_check_now = last_tokens, now_tokens
        for token_a, token_b in zip(tokens_check_pre[:unmodified_token_cnt], tokens_check_now[:unmodified_token_cnt]):
            if token_a != token_b:
                is_unmodified_token = True
                break
        if is_unmodified_token and int(os.getenv('test', 0)):
            logger.error(
                f"{idx}-{unmodified_token_cnt}-{last_tokens[:unmodified_token_cnt]}-{now_tokens[:unmodified_token_cnt]}"
            )
        if is_unmodified_token:
            break
    if is_unmodified_token:
        logger.error("the unmodifiable text range was modified")
        # change_product_available()

    if int(os.getenv('test', 0)):
        final_result = True
        result_list = []
        for tokens, pred in zip(pred_text_tokenized_list, context.preds):
            if final_result:
                result_list.append([])
            result_list[-1].append((tokens, pred.recv_time - context.preds[0].recv_time))
            final_result = pred.recognition_results.final_result
        for item in result_list:
            logger.info(str(item))

    # record the token count of every patch
    patch_unique_cnt_counter = Counter()
    patch_unique_cnt_in_one_sentence = set()
    for pred_text_tokenized, pred_context in zip(pred_text_tokenized_list, context.preds):
        token_cnt = len(pred_text_tokenized)
        patch_unique_cnt_in_one_sentence.add(token_cnt)
        if pred_context.recognition_results.final_result:
            for unique_cnt in patch_unique_cnt_in_one_sentence:
                patch_unique_cnt_counter[unique_cnt] += 1
            patch_unique_cnt_in_one_sentence.clear()
    if context.preds and not context.preds[-1].recognition_results.final_result:
        for unique_cnt in patch_unique_cnt_in_one_sentence:
            patch_unique_cnt_counter[unique_cnt] += 1
    # print(patch_unique_cnt_counter)
    logger.info(
        f"mean patch token count for the current audio: {mean_on_counter(patch_unique_cnt_counter)}, "
        f"patch token count variance for the current audio: {var_on_counter(patch_unique_cnt_counter)}"
    )
    return patch_unique_cnt_counter
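
# Worked example (hypothetical numbers): if one sentence produces partial
# results with 1, 2, 2 and 3 tokens, the distinct counts {1, 2, 3} each add
# one to the counter, so the sentence contributes Counter({1: 1, 2: 1, 3: 1}).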


def mean_on_counter(counter: Counter):
    total_sum = sum(key * count for key, count in counter.items())
    total_count = sum(counter.values())
    return total_sum * 1.0 / total_count


def var_on_counter(counter: Counter):
    total_sum = sum(key * count for key, count in counter.items())
    total_count = sum(counter.values())
    mean = total_sum * 1.0 / total_count
    return sum((key - mean) ** 2 * count for key, count in counter.items()) / total_count
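
# Worked example (hypothetical counts): Counter({2: 3, 5: 1}) represents the
# observations 2, 2, 2, 5, so mean_on_counter(...) == 11 / 4 == 2.75 and
# var_on_counter(...) == (3 * 0.75 ** 2 + 2.25 ** 2) / 4 == 1.6875.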


def edit_distance(arr1: List, arr2: List):
    operations = Levenshtein.editops(arr1, arr2)
    i = sum([1 for operation in operations if operation[0] == "insert"])
    s = sum([1 for operation in operations if operation[0] == "replace"])
    d = sum([1 for operation in operations if operation[0] == "delete"])
    c = len(arr1) - s - d
    return s, d, i, c
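
# Worked example (hypothetical tokens): aligning "kitten" with "sitting" takes
# two substitutions (k->s, e->i) and one insertion (g), so
# edit_distance(list("kitten"), list("sitting")) would return
# (s, d, i, c) == (2, 0, 1, 4), with c == 6 - 2 - 0 correct tokens.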


def cer(tokens_gt_mapping: List[str], tokens_dt_mapping: List[str]):
    """Take two token sequences already aligned by the edit-distance mapping and return (1 - cer, token count)."""
    insert = sum(1 for item in tokens_gt_mapping if item is None)
    delete = sum(1 for item in tokens_dt_mapping if item is None)
    equal = sum(1 for token_gt, token_dt in zip(tokens_gt_mapping, tokens_dt_mapping) if token_gt == token_dt)
    replace = len(tokens_gt_mapping) - insert - equal
    token_count = replace + equal + delete
    cer_value = (replace + delete + insert) * 1.0 / token_count
    logger.info(f"cer/wer of the current audio: {cer_value}, token count: {token_count}")
    return 1 - cer_value, token_count
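
# Worked example (hypothetical tokens): with gt mapping ["a", "b", None, "c"]
# and dt mapping ["a", "x", "y", "c"] there is 1 insertion (the None in gt),
# 1 substitution ("b" -> "x"), 2 matches and 3 ground-truth tokens, so
# cer(...) returns (1 - 2 / 3, 3).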


def cut_rate(
    tokens_gt: List[List[str]],
    tokens_dt: List[List[str]],
    tokens_gt_mapping: List[str],
    tokens_dt_mapping: List[str],
):
    sentence_final_token_index_gt = sentence_final_token_index(tokens_gt, tokens_gt_mapping)
    sentence_final_token_index_dt = sentence_final_token_index(tokens_dt, tokens_dt_mapping)
    sentence_final_token_index_gt = set(sentence_final_token_index_gt)
    sentence_final_token_index_dt = set(sentence_final_token_index_dt)
    sentence_count_gt = len(sentence_final_token_index_gt)
    miss_count = len(sentence_final_token_index_gt - sentence_final_token_index_dt)
    more_count = len(sentence_final_token_index_dt - sentence_final_token_index_gt)
    rate = max(1 - (miss_count + more_count * 2) / sentence_count_gt, 0)
    return rate, sentence_count_gt, miss_count, more_count
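
# Worked example (hypothetical numbers): with 10 ground-truth sentence
# boundaries, 2 of them missed and 1 extra boundary predicted, the cut rate is
# max(1 - (2 + 1 * 2) / 10, 0) == 0.6.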


def token_mapping(tokens_gt: List[str], tokens_dt: List[str]) -> Tuple[List[str], List[str]]:
    arr1 = deepcopy(tokens_gt)
    arr2 = deepcopy(tokens_dt)
    operations = Levenshtein.editops(arr1, arr2)
    for op in operations[::-1]:
        if op[0] == "insert":
            arr1.insert(op[1], None)
        elif op[0] == "delete":
            arr2.insert(op[2], None)
    return arr1, arr2
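
# Worked example (hypothetical tokens): token_mapping(["a", "b", "c"], ["a", "c"])
# pads the hypothesis with a placeholder for the deleted token and would return
# (["a", "b", "c"], ["a", None, "c"]), so both sequences end up the same length.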


def sentence_final_token_index(tokens: List[List[str]], tokens_mapping: List[str]) -> List[int]:
    """Return the index, within the mapped sequence, of the final token of each sentence."""
    token_index_list = []
    token_index = 0
    for token_in_one_sentence in tokens:
        for _ in range(len(token_in_one_sentence)):
            while token_index < len(tokens_mapping) and tokens_mapping[token_index] is None:
                token_index += 1
            token_index += 1
        token_index_list.append(token_index - 1)
    return token_index_list
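
# Worked example (hypothetical tokens): for tokens [["a", "b"], ["c"]] and the
# mapped sequence ["a", None, "b", "c"], the first sentence ends at mapped
# index 2 and the second at index 3, so the function returns [2, 3].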


def cut_sentence(sentences: List[str], tokenizerType: TokenizerType) -> List[str]:
    """Split every sentence on the punctuation marks below, then merge the pieces back into one string per input."""
    sentence_cut_list = []
    for sentence in sentences:
        sentence_list = [sentence]
        sentence_tmp_list = []
        for punc in [
            "······",
            "......",
            # full-width CJK punctuation
            "。",
            "，",
            "？",
            "！",
            "；",
            "：",
            "...",
            ".",
            ",",
            "?",
            "!",
            ";",
            ":",
        ]:
            for sentence in sentence_list:
                sentence_tmp_list.extend(sentence.split(punc))
            sentence_list, sentence_tmp_list = sentence_tmp_list, []
        sentence_list = [item for item in sentence_list if item]
        if tokenizerType == TokenizerType.whitespace:
            sentence_cut_list.append(" ".join(sentence_list))
        else:
            sentence_cut_list.append("".join(sentence_list))
    return sentence_cut_list
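
# Worked example (hypothetical input): cut_sentence(["a,b.c"], TokenizerType.whitespace)
# strips the punctuation and would return ["a b c"]; with any other tokenizer
# type the pieces are joined without spaces, giving ["abc"].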