Files
enginex-bi_series-vc-cnn/utils/evaluator_plus.py
zhousha 55a67e817e update
2025-08-06 15:38:55 +08:00

294 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import bisect
from collections import defaultdict
from copy import deepcopy
from itertools import chain
from typing import Dict, List, Tuple

import Levenshtein

from schemas.dataset import QueryData
from schemas.stream import StreamDataModel, StreamWordsModel
from utils.metrics import Tokenizer
from utils.metrics_plus import replace_general_punc
from utils.tokenizer import TOKENIZER_MAPPING
def evaluate_editops(
    query_data: QueryData, recognition_results: List[StreamDataModel]
) -> Tuple[float, int, Dict[int, int], Dict[int, int], float, float]:
    """Evaluate CER and first/last-word time alignment of streaming results.

    Returns a 6-tuple of:
      * cer -- (inserts + deletes + replaces) / number of label tokens
      * total sentence count (``len(query_data.voice)``)
      * first-word alignment: {threshold -> number of sentences whose first
        word's predicted start time is within the threshold of the label's}
      * last-word alignment: the same mapping, for last-word end times
      * sum of first-word time differences
      * sum of last-word time differences
    Alignment maps are "time-difference threshold -> aligned-sentence count".
    """
    # Copy first: recognition_result.words is rewritten in place below.
    recognition_results = deepcopy(recognition_results)
    lang = query_data.lang
    voices = query_data.voice
    sentences_pred = [
        recognition_result.text for recognition_result in recognition_results
    ]
    sentences_label = [item.answer for item in voices]
    tokenizer_type = TOKENIZER_MAPPING[lang]
    # Normalize punctuation identically for prediction and label.
    sentences_pred = replace_general_punc(sentences_pred, tokenizer_type)
    sentences_label = replace_general_punc(sentences_label, tokenizer_type)
    # norm & tokenize
    tokens_pred = Tokenizer.norm_and_tokenize(sentences_pred, lang)
    tokens_label = Tokenizer.norm_and_tokenize(sentences_label, lang)
    normed_words = []
    for recognition_result in recognition_results:
        words = list(map(lambda x: x.text, recognition_result.words))
        normed_words.extend(words)
    normed_words = replace_general_punc(normed_words, tokenizer_type)
    normed_words = Tokenizer.norm(normed_words, lang)
    # Apply the same norm and tokenize treatment to the per-word stream data.
    normed_word_index = 0
    for recognition_result in recognition_results:
        next_index = normed_word_index + len(recognition_result.words)
        tokens_words = Tokenizer.tokenize(
            normed_words[normed_word_index:next_index], lang
        )
        normed_word_index = next_index
        stream_words: List[StreamWordsModel] = []
        # Each token produced from a raw word inherits that word's timestamps.
        for raw_stream_word, tokens_word in zip(
            recognition_result.words, tokens_words
        ):
            for word in tokens_word:
                stream_words.append(
                    StreamWordsModel(
                        text=word,
                        start_time=raw_stream_word.start_time,
                        end_time=raw_stream_word.end_time,
                    )
                )
        recognition_result.words = stream_words
    # Match the timed words against the sentence-level tokenization so every
    # token in tokens_pred carries a start/end time.
    pred_word_time: List[StreamWordsModel] = []
    for token_pred, recognition_result in zip(tokens_pred, recognition_results):
        word_index = 0
        for word in recognition_result.words:
            try:
                token_index = token_pred.index(word.text, word_index)
                # Tokens up to and including the match take this word's times.
                for i in range(word_index, token_index + 1):
                    pred_word_time.append(
                        StreamWordsModel(
                            text=token_pred[i],
                            start_time=word.start_time,
                            end_time=word.end_time,
                        )
                    )
                word_index = token_index + 1
            except ValueError:
                # Token not found at/after word_index; skip this word.
                pass
        # Remaining unmatched tokens fall back to the last word's times, or to
        # the whole utterance's times when there were no words at all.
        if len(recognition_result.words) > 0:
            word = recognition_result.words[-1]
            start_time = word.start_time
            end_time = word.end_time
        else:
            start_time = recognition_result.start_time
            end_time = recognition_result.end_time
        for i in range(word_index, len(token_pred)):
            pred_word_time.append(
                StreamWordsModel(
                    text=token_pred[i],
                    start_time=start_time,
                    end_time=end_time,
                )
            )
    # Record, per label sentence, the flat-token index of its first and last
    # token.
    index = 0
    label_firstword_index: List[int] = []
    label_lastword_index: List[int] = []
    for token_label in tokens_label:
        label_firstword_index.append(index)
        index += len(token_label)
        label_lastword_index.append(index - 1)
    # cer
    flat_tokens_pred = list(chain(*tokens_pred))
    flat_tokens_label = list(chain(*tokens_label))
    ops = Levenshtein.editops(flat_tokens_pred, flat_tokens_label)
    insert = len(list(filter(lambda x: x[0] == "insert", ops)))
    delete = len(list(filter(lambda x: x[0] == "delete", ops)))
    replace = len(list(filter(lambda x: x[0] == "replace", ops)))
    # NOTE(review): raises ZeroDivisionError when the label has no tokens --
    # confirm upstream guarantees a non-empty label.
    cer = (insert + delete + replace) / len(flat_tokens_label)
    # Compute every token's index in the edit-aligned sequence: an "insert"
    # opens a gap on the prediction side, a "delete" opens one on the label
    # side; the offsets are accumulated into absolute positions below.
    pred_offset = [0] * (len(flat_tokens_pred) + 1)
    label_offset = [0] * (len(flat_tokens_label) + 1)
    for op in ops:
        if op[0] == "insert":
            pred_offset[op[1]] += 1
        elif op[0] == "delete":
            label_offset[op[2]] += 1
    pred_indexs = [pred_offset[0]]
    for i in range(1, len(flat_tokens_pred)):
        pred_indexs.append(pred_indexs[i - 1] + pred_offset[i] + 1)
    label_indexs = [label_offset[0]]
    for i in range(1, len(flat_tokens_label)):
        label_indexs.append(label_indexs[i - 1] + label_offset[i] + 1)
    # For each label sentence, locate the predicted token aligned with its
    # first and last token and compare timestamps against the label voice.
    align_start = {100: 0, 200: 0, 300: 0, 500: 0}
    align_end = {100: 0, 200: 0, 300: 0, 500: 0}
    first_word_distance_sum = 0.0
    last_word_distance_sum = 0.0
    for firstword_index, lastword_index, voice in zip(
        label_firstword_index, label_lastword_index, voices
    ):
        label_index = label_indexs[firstword_index]
        # First predicted token at/after the label's aligned position.
        label_in_pred_index = upper_bound(label_index, pred_indexs)
        if label_in_pred_index != -1:
            distance = abs(
                voice.start - pred_word_time[label_in_pred_index].start_time
            )
            # The previous token may be a closer match; take the smaller gap.
            if label_in_pred_index > 0:
                distance = min(
                    distance,
                    abs(
                        voice.start
                        - pred_word_time[label_in_pred_index - 1].start_time
                    ),
                )
        else:
            # Nothing at/after the position: fall back to the last token.
            distance = abs(voice.start - pred_word_time[-1].start_time)
        # Thresholds appear to be milliseconds while times look like seconds
        # (cf. the 0.7 window in evaluate_punctuation) -- hence / 1000.
        for limit in align_start.keys():
            if distance <= limit / 1000:
                align_start[limit] += 1
        first_word_distance_sum += distance
        label_index = label_indexs[lastword_index]
        # Last predicted token at/before the label's aligned position.
        label_in_pred_index = lower_bound(label_index, pred_indexs)
        if label_in_pred_index != -1:
            distance = abs(
                voice.end - pred_word_time[label_in_pred_index].end_time
            )
            # The next token may be a closer match; take the smaller gap.
            if label_in_pred_index < len(pred_word_time) - 1:
                distance = min(
                    distance,
                    abs(
                        voice.end
                        - pred_word_time[label_in_pred_index + 1].end_time
                    ),
                )
        else:
            # Nothing at/before the position: fall back to the first token.
            distance = abs(voice.end - pred_word_time[0].end_time)
        for limit in align_end.keys():
            if distance <= limit / 1000:
                align_end[limit] += 1
        last_word_distance_sum += distance
    return (
        cer,
        len(voices),
        align_start,
        align_end,
        first_word_distance_sum,
        last_word_distance_sum,
    )
def evaluate_punctuation(
    query_data: QueryData, recognition_results: List[StreamDataModel]
) -> Tuple[int, int, int, int]:
    """Evaluate punctuation metrics.

    Returns a 4-tuple of:
      * number of predicted words containing a punctuation character
      * number of punctuation characters in the label answers
      * number of label sentences with a predicted punctuation word near
        the sentence boundary
      * number of label sentences (denominator for the previous count)
    """
    # Language-specific punctuation sets; unlisted languages fall back to the
    # ASCII set via the defaultdict factory.
    # NOTE(review): the "zh"/"ja"/"el"/"ti" entries contain empty strings that
    # look like non-ASCII punctuation (e.g. fullwidth CJK marks) stripped by a
    # copy/encoding step -- verify against the original file. An empty string
    # never matches the per-character membership test below.
    punctuation_mapping = defaultdict(lambda: [",", ".", "!", "?"])
    punctuation_mapping.update(
        {
            "zh": ["", "", "", ""],
            "ja": ["", "", "", ""],
            "ar": ["،", ".", "!", "؟"],
            "fa": ["،", ".", "!", "؟"],
            "el": [",", ".", "", ""],
            "ti": [""],
            "th": [" ", ",", ".", "!", "?"],
        }
    )
    # Collect every predicted word containing at least one punctuation char.
    punctuation_words: List[StreamWordsModel] = []
    for recognition_result in recognition_results:
        punctuations = punctuation_mapping[query_data.lang]
        for word in recognition_result.words:
            for char in word.text:
                if char in punctuations:
                    punctuation_words.append(word)
                    break
    # Sorted start/end times enable the binary searches below.
    punctuation_start_times = list(
        map(lambda x: x.start_time, punctuation_words)
    )
    punctuation_start_times = sorted(punctuation_start_times)
    punctuation_end_times = list(map(lambda x: x.end_time, punctuation_words))
    punctuation_end_times = sorted(punctuation_end_times)
    voices = query_data.voice
    label_len = len(voices)
    pred_punctuation_num = len(punctuation_words)
    # Count punctuation characters in the label answers.
    label_punctuation_num = 0
    for label_voice in voices:
        punctuations = punctuation_mapping[query_data.lang]
        for char in label_voice.answer:
            if char in punctuations:
                label_punctuation_num += 1
    # Sentence-boundary punctuation: for each label sentence, look for a
    # predicted punctuation word in the window from this sentence's end to
    # the next sentence's start (or end +- 0.7 for the final sentence).
    pred_sentence_punctuation_num = 0
    label_setence_punctuation_num = label_len
    for i, label_voice in enumerate(voices):
        if i < label_len - 1:
            label_left = label_voice.end
            label_right = voices[i + 1].start
        else:
            label_left = label_voice.end - 0.7
            label_right = label_voice.end + 0.7
        # Hit if some punctuation word starts inside the window ...
        left_in_pred = upper_bound(label_left, punctuation_start_times)
        exist = False
        if (
            left_in_pred != -1
            and punctuation_start_times[left_in_pred] <= label_right
        ):
            exist = True
        # ... or ends inside the window.
        right_in_pred = lower_bound(label_right, punctuation_end_times)
        if (
            right_in_pred != -1
            and punctuation_end_times[right_in_pred] >= label_left
        ):
            exist = True
        if exist:
            pred_sentence_punctuation_num += 1
    return (
        pred_punctuation_num,
        label_punctuation_num,
        pred_sentence_punctuation_num,
        label_setence_punctuation_num,
    )
def upper_bound(x: float, lst: List[float]) -> int:
    """Return the index of the first element >= x, or -1 if none exists.

    ``lst`` must be sorted ascending. (Despite the name, these are C++
    ``std::lower_bound`` semantics; the name is kept for the callers above.)
    """
    # bisect_left yields the leftmost position where x could be inserted,
    # i.e. the index of the first element >= x; len(lst) means no such
    # element (covers the empty-list case too).
    i = bisect.bisect_left(lst, x)
    return i if i < len(lst) else -1
def lower_bound(x: float, lst: List[float]) -> int:
    """Return the index of the last element <= x, or -1 if none exists.

    ``lst`` must be sorted ascending.
    """
    # bisect_right yields the rightmost insertion position for x; the element
    # just before it is the last one <= x. When every element is > x (or the
    # list is empty) bisect_right returns 0, giving -1 as required.
    return bisect.bisect_right(lst, x) - 1