"""Streaming ASR evaluation metrics: edit-op timestamp alignment and punctuation."""
import bisect
from collections import defaultdict
from copy import deepcopy
from itertools import chain
from typing import Dict, List, Tuple

import Levenshtein

from schemas.dataset import QueryData
from schemas.stream import StreamDataModel, StreamWordsModel
from utils.metrics import Tokenizer
from utils.metrics_plus import replace_general_punc
from utils.tokenizer import TOKENIZER_MAPPING


def evaluate_editops(
    query_data: QueryData, recognition_results: List[StreamDataModel]
) -> Tuple[float, int, Dict[int, int], Dict[int, int], float, float]:
    """Compute CER and first/last-word timestamp-alignment statistics.

    Returns a tuple of:
        cer: token-level character error rate (edit ops / label token count).
        sentence count: number of label voices.
        first-word alignment: dict mapping threshold (ms) -> count of
            sentences whose first-word time delta is within the threshold.
        last-word alignment: same, for the last word of each sentence.
        sum of first-word time deltas.
        sum of last-word time deltas.

    NOTE(review): raises ZeroDivisionError if the label token list is empty
    — presumably callers guarantee non-empty labels; confirm.
    """
    # Deep-copy because the per-result `words` lists are rebuilt in place below.
    recognition_results = deepcopy(recognition_results)
    lang = query_data.lang
    voices = query_data.voice
    sentences_pred = [
        recognition_result.text for recognition_result in recognition_results
    ]
    sentences_label = [item.answer for item in voices]

    tokenizer_type = TOKENIZER_MAPPING[lang]
    sentences_pred = replace_general_punc(sentences_pred, tokenizer_type)
    sentences_label = replace_general_punc(sentences_label, tokenizer_type)

    # norm & tokenize
    tokens_pred = Tokenizer.norm_and_tokenize(sentences_pred, lang)
    tokens_label = Tokenizer.norm_and_tokenize(sentences_label, lang)

    # Flatten all per-result word texts so punctuation replacement and
    # normalization are applied with the same settings as the sentences.
    normed_words = []
    for recognition_result in recognition_results:
        words = list(map(lambda x: x.text, recognition_result.words))
        normed_words.extend(words)
    normed_words = replace_general_punc(normed_words, tokenizer_type)
    normed_words = Tokenizer.norm(normed_words, lang)

    # Apply the same norm/tokenize treatment to the per-word predictions.
    normed_word_index = 0
    for recognition_result in recognition_results:
        # Slice back out this result's words from the flat normalized list.
        next_index = normed_word_index + len(recognition_result.words)
        tokens_words = Tokenizer.tokenize(
            normed_words[normed_word_index:next_index], lang
        )
        normed_word_index = next_index
        stream_words: List[StreamWordsModel] = []
        # Each raw word may tokenize into several tokens; every token
        # inherits the start/end time of its original raw word.
        for raw_stream_word, tokens_word in zip(
            recognition_result.words, tokens_words
        ):
            for word in tokens_word:
                stream_words.append(
                    StreamWordsModel(
                        text=word,
                        start_time=raw_stream_word.start_time,
                        end_time=raw_stream_word.end_time,
                    )
                )
        recognition_result.words = stream_words

    # Align the timed words against the sentence-level tokenization so that
    # every sentence token in tokens_pred gets a timestamp.
    pred_word_time: List[StreamWordsModel] = []
    for token_pred, recognition_result in zip(tokens_pred, recognition_results):
        word_index = 0
        for word in recognition_result.words:
            try:
                # Find this timed word in the sentence tokens; any tokens
                # skipped over (no direct match) also receive this word's time.
                token_index = token_pred.index(word.text, word_index)
                for i in range(word_index, token_index + 1):
                    pred_word_time.append(
                        StreamWordsModel(
                            text=token_pred[i],
                            start_time=word.start_time,
                            end_time=word.end_time,
                        )
                    )
                word_index = token_index + 1
            except ValueError:
                # Timed word not found among remaining sentence tokens —
                # deliberately skipped; trailing tokens are timed below.
                pass
        # Remaining sentence tokens fall back to the last timed word's times,
        # or to the whole result's times if it had no words.
        if len(recognition_result.words) > 0:
            word = recognition_result.words[-1]
            start_time = word.start_time
            end_time = word.end_time
        else:
            start_time = recognition_result.start_time
            end_time = recognition_result.end_time
        for i in range(word_index, len(token_pred)):
            pred_word_time.append(
                StreamWordsModel(
                    text=token_pred[i],
                    start_time=start_time,
                    end_time=end_time,
                )
            )

    # Record, for each label sentence, the flat-token positions of its
    # first and last token.
    index = 0
    label_firstword_index: List[int] = []
    label_lastword_index: List[int] = []
    for token_label in tokens_label:
        label_firstword_index.append(index)
        index += len(token_label)
        label_lastword_index.append(index - 1)

    # cer
    flat_tokens_pred = list(chain(*tokens_pred))
    flat_tokens_label = list(chain(*tokens_label))
    ops = Levenshtein.editops(flat_tokens_pred, flat_tokens_label)
    insert = len(list(filter(lambda x: x[0] == "insert", ops)))
    delete = len(list(filter(lambda x: x[0] == "delete", ops)))
    replace = len(list(filter(lambda x: x[0] == "replace", ops)))
    cer = (insert + delete + replace) / len(flat_tokens_label)

    # Compute each token's index position in the post-edit (aligned) sequence:
    # an insert shifts subsequent pred tokens right; a delete shifts label
    # tokens right. pred_indexs/label_indexs are then the cumulative positions.
    pred_offset = [0] * (len(flat_tokens_pred) + 1)
    label_offset = [0] * (len(flat_tokens_label) + 1)
    for op in ops:
        if op[0] == "insert":
            pred_offset[op[1]] += 1
        elif op[0] == "delete":
            label_offset[op[2]] += 1
    pred_indexs = [pred_offset[0]]
    for i in range(1, len(flat_tokens_pred)):
        pred_indexs.append(pred_indexs[i - 1] + pred_offset[i] + 1)
    label_indexs = [label_offset[0]]
    for i in range(1, len(flat_tokens_label)):
        label_indexs.append(label_indexs[i - 1] + label_offset[i] + 1)

    # Compute the predicted time for the first and last token of each label
    # sentence. Threshold keys are in milliseconds; `limit / 1000` implies
    # the start/end times themselves are in seconds — TODO confirm.
    align_start = {100: 0, 200: 0, 300: 0, 500: 0}
    align_end = {100: 0, 200: 0, 300: 0, 500: 0}
    first_word_distance_sum = 0.0
    last_word_distance_sum = 0.0
    for firstword_index, lastword_index, voice in zip(
        label_firstword_index, label_lastword_index, voices
    ):
        # First word: find the first pred token aligned at/after the label
        # token's aligned position; also consider its left neighbor and keep
        # the smaller time delta.
        label_index = label_indexs[firstword_index]
        label_in_pred_index = upper_bound(label_index, pred_indexs)
        if label_in_pred_index != -1:
            distance = abs(
                voice.start - pred_word_time[label_in_pred_index].start_time
            )
            if label_in_pred_index > 0:
                distance = min(
                    distance,
                    abs(
                        voice.start
                        - pred_word_time[label_in_pred_index - 1].start_time
                    ),
                )
        else:
            # No pred token at/after this position: fall back to the last one.
            distance = abs(voice.start - pred_word_time[-1].start_time)
        for limit in align_start.keys():
            if distance <= limit / 1000:
                align_start[limit] += 1
        first_word_distance_sum += distance

        # Last word: symmetric search toward the left, considering the right
        # neighbor as the alternative candidate.
        label_index = label_indexs[lastword_index]
        label_in_pred_index = lower_bound(label_index, pred_indexs)
        if label_in_pred_index != -1:
            distance = abs(
                voice.end - pred_word_time[label_in_pred_index].end_time
            )
            if label_in_pred_index < len(pred_word_time) - 1:
                distance = min(
                    distance,
                    abs(
                        voice.end
                        - pred_word_time[label_in_pred_index + 1].end_time
                    ),
                )
        else:
            # No pred token at/before this position: fall back to the first one.
            distance = abs(voice.end - pred_word_time[0].end_time)
        for limit in align_end.keys():
            if distance <= limit / 1000:
                align_end[limit] += 1
        last_word_distance_sum += distance
    return (
        cer,
        len(voices),
        align_start,
        align_end,
        first_word_distance_sum,
        last_word_distance_sum,
    )


def evaluate_punctuation(
    query_data: QueryData, recognition_results: List[StreamDataModel]
) -> Tuple[int, int, int, int]:
    """Evaluate punctuation metrics.

    Returns a tuple of:
        predicted punctuation-word count,
        label punctuation-character count,
        number of label sentence boundaries covered by a predicted
            punctuation word,
        label sentence count.
    """
    # Per-language punctuation sets; unknown languages fall back to ASCII.
    punctuation_mapping = defaultdict(lambda: [",", ".", "!", "?"])
    punctuation_mapping.update(
        {
            "zh": [",", "。", "!", "?"],
            "ja": ["、", "。", "!", "?"],
            "ar": ["،", ".", "!", "؟"],
            "fa": ["،", ".", "!", "؟"],
            "el": [",", ".", "!", ";"],
            "ti": ["།"],
            "th": [" ", ",", ".", "!", "?"],
        }
    )
    punctuations = punctuation_mapping[query_data.lang]

    # Collect every predicted word that contains at least one punctuation char.
    punctuation_words: List[StreamWordsModel] = []
    for result in recognition_results:
        for stream_word in result.words:
            if any(char in punctuations for char in stream_word.text):
                punctuation_words.append(stream_word)
    punctuation_start_times = sorted(w.start_time for w in punctuation_words)
    punctuation_end_times = sorted(w.end_time for w in punctuation_words)

    voices = query_data.voice
    label_len = len(voices)
    pred_punctuation_num = len(punctuation_words)
    # Count individual punctuation characters in the reference answers.
    label_punctuation_num = sum(
        1
        for label_voice in voices
        for char in label_voice.answer
        if char in punctuations
    )

    # A sentence boundary counts as punctuated when some predicted
    # punctuation word overlaps the window between this sentence's end and
    # the next sentence's start (±0.7 s around the end for the last one).
    pred_sentence_punctuation_num = 0
    for i, label_voice in enumerate(voices):
        if i < label_len - 1:
            label_left = label_voice.end
            label_right = voices[i + 1].start
        else:
            label_left = label_voice.end - 0.7
            label_right = label_voice.end + 0.7

        # Earliest punctuation word starting at/after the window's left edge.
        left_idx = upper_bound(label_left, punctuation_start_times)
        covered = (
            left_idx != -1 and punctuation_start_times[left_idx] <= label_right
        )
        if not covered:
            # Latest punctuation word ending at/before the right edge.
            right_idx = lower_bound(label_right, punctuation_end_times)
            covered = (
                right_idx != -1
                and punctuation_end_times[right_idx] >= label_left
            )
        if covered:
            pred_sentence_punctuation_num += 1
    return (
        pred_punctuation_num,
        label_punctuation_num,
        pred_sentence_punctuation_num,
        label_len,
    )


def upper_bound(x: float, lst: List[float]) -> int:
    """Return the index of the first element >= x, or -1 if none exists.

    `lst` must be sorted in ascending order.

    NOTE: despite the name, this is C++ `lower_bound` semantics (first
    element >= x), kept for compatibility with existing callers.
    """
    # Idiom: stdlib bisect replaces the hand-rolled binary search.
    # bisect_left returns the leftmost insertion point, i.e. the index of
    # the first element >= x, or len(lst) when every element is < x.
    i = bisect.bisect_left(lst, x)
    return i if i < len(lst) else -1


def lower_bound(x: float, lst: List[float]) -> int:
    """Return the index of the last element <= x, or -1 if none exists.

    `lst` must be sorted in ascending order.
    """
    # Idiom: stdlib bisect replaces the hand-rolled binary search.
    # bisect_right returns the index of the first element > x, so one less
    # is the last element <= x; when every element is > x this yields -1.
    return bisect.bisect_right(lst, x) - 1