Files
enginex-mr_series-asr/utils/calculate.py
2025-08-20 14:29:42 +08:00

835 lines
29 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
import time
import Levenshtein
from utils.tokenizer import Tokenizer
from typing import List, Tuple
from utils.model import SegmentModel
from utils.model import AudioItem
from utils.reader import read_data
from utils.logger import logger
from utils.model import VoiceSegment
from utils.model import WordModel
from difflib import SequenceMatcher
import logging
def calculate_punctuation_ratio(datas: List[Tuple[AudioItem, List[SegmentModel]]]) -> float:
    """
    Compute the punctuation ratio: generated punctuation count over
    reference punctuation count, summed across all samples.

    :param datas: pairs of (reference AudioItem, list of ASR segment results)
    :return: total punctuation in generated text / total punctuation in the
             reference text; 0.0 when the reference contains no punctuation
             at all (the original raised ZeroDivisionError in that case).
    """
    total_standard_punctuation = 0
    total_gen_punctuation = 0
    for answer, results in datas:
        # Concatenate all reference segments and all generated segments,
        # then count punctuation marks in each side.
        standard_text = "".join(item.answer for item in answer.voice)
        gen_text = "".join(item.text for item in results)
        total_standard_punctuation += count_punctuation(standard_text)
        total_gen_punctuation += count_punctuation(gen_text)
    # Guard: a punctuation-free reference corpus would divide by zero.
    if total_standard_punctuation == 0:
        return 0.0
    return total_gen_punctuation / total_standard_punctuation
def calculate_acc(datas: List[Tuple[AudioItem, List[SegmentModel]]], language: str) -> float:
    """
    Compute the mean accuracy (1 - CER) over all samples.

    :param datas: pairs of (reference AudioItem, list of ASR segment results)
    :param language: language code, used to pick char- vs word-level alignment
    :return: average per-sample accuracy; 0.0 for an empty dataset
             (the original raised ZeroDivisionError on empty input).
    """
    if not datas:
        return 0.0
    total_acc = 0.0
    for answer, results in datas:
        # Join all reference / generated segment texts and score them
        # as one string against each other.
        standard_text = "".join(item.answer for item in answer.voice)
        gen_text = "".join(item.text for item in results)
        total_acc += cal_per_cer(gen_text, standard_text, language)
    return total_acc / len(datas)
def get_alignment_type(language: str):
    """
    Return the alignment granularity for *language*.

    "chart" (character-based) for scripts written without word separators —
    Chinese, Japanese, Korean, Thai, Lao, Burmese, Khmer, Tibetan —
    and "word" for everything else.
    """
    char_based = {"zh", "ja", "ko", "th", "lo", "my", "km", "bo"}
    return "chart" if language in char_based else "word"
def cal_per_cer(text: str, answer: str, language: str):
    """
    Accuracy (1 - CER/WER) of hypothesis *text* against reference *answer*.

    Uses character-level Levenshtein edit operations for "chart" languages
    and token-level SequenceMatcher opcodes otherwise.

    :param text: hypothesis text
    :param answer: reference text
    :param language: language code (selects char vs. word alignment)
    :return: accuracy; 1.0 is a perfect match (can go below 0 for very
             long hypotheses, since insertions are counted against a
             short reference)
    """
    if not answer:
        # Empty reference: perfect (1.0) only when the hypothesis is empty too.
        # BUGFIX: the original returned CER-style values here (1.0 when text
        # was non-empty), although every other path returns accuracy — an
        # empty reference with spurious text was scored as a perfect match.
        return 0.0 if text else 1.0
    text = remove_punctuation(text)
    answer = remove_punctuation(answer)
    text_chars = Tokenizer.norm_and_tokenize([text], language)[0]
    answer_chars = Tokenizer.norm_and_tokenize([answer], language)[0]
    if not answer_chars:
        # Reference vanished after normalization; treated as fully wrong.
        # NOTE(review): arguably this should mirror the empty-`answer` case
        # above (1.0 when text_chars is also empty) — confirm the intent.
        return 0.0
    alignment_type = get_alignment_type(language)
    if alignment_type == "chart":
        # Character-level comparison on the punctuation-stripped strings.
        text_chars = list(text)
        answer_chars = list(answer)
        ops = Levenshtein.editops(text_chars, answer_chars)
        insert = sum(1 for op in ops if op[0] == "insert")
        delete = sum(1 for op in ops if op[0] == "delete")
        replace = sum(1 for op in ops if op[0] == "replace")
    else:
        matcher = SequenceMatcher(None, text_chars, answer_chars)
        insert = delete = replace = 0
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag == 'replace':
                # Unequal replace spans count as the longer side.
                replace += max(i2 - i1, j2 - j1)
            elif tag == 'delete':
                delete += (i2 - i1)
            elif tag == 'insert':
                insert += (j2 - j1)
    cer = (insert + delete + replace) / len(answer_chars)
    return 1 - cer
def cal_total_cer(samples: list):
    """
    Corpus-level accuracy (1 - pooled CER) over (pred, ref) pairs.

    :param samples: list of tuples [(pred_text, ref_text), ...]
    :return: 1 - (total edit operations / total reference length);
             1.0 when there is neither reference nor predicted text.
    """
    total_insert = 0
    total_delete = 0
    total_replace = 0
    total_ref_len = 0
    for text, answer in samples:
        if not answer:
            # BUGFIX: the original `return`ed here, silently aborting the
            # whole aggregation at the first sample with an empty reference.
            # Instead, count every predicted character as an insertion error
            # and keep accumulating.
            total_insert += len(remove_punctuation(text))
            continue
        text = remove_punctuation(text)
        answer = remove_punctuation(answer)
        text_chars = list(text)
        answer_chars = list(answer)
        ops = Levenshtein.editops(text_chars, answer_chars)
        total_insert += sum(1 for op in ops if op[0] == "insert")
        total_delete += sum(1 for op in ops if op[0] == "delete")
        total_replace += sum(1 for op in ops if op[0] == "replace")
        total_ref_len += len(answer_chars)
    total_errors = total_insert + total_delete + total_replace
    if total_ref_len > 0:
        total_cer = total_errors / total_ref_len
    else:
        # No reference text at all: CER is 1 if anything was predicted, else 0.
        total_cer = 1.0 if total_errors else 0.0
    return 1 - total_cer
def remove_punctuation(text: str) -> str:
    """
    Strip every character that is not a word character, whitespace, or a
    CJK ideograph — i.e. drop Chinese and English punctuation.
    """
    pattern = re.compile(r'[^\w\s\u4e00-\u9fff]')
    return pattern.sub('', text)
def count_punctuation(text: str) -> int:
    """
    Count punctuation-like characters: anything that is not a word
    character, whitespace, or a CJK ideograph in U+4E00–U+9FA5.
    """
    pattern = re.compile(r"[^\w\s\u4e00-\u9fa5]")
    return sum(1 for _ in pattern.finditer(text))
from typing import List, Optional, Tuple
def calculate_standard_sentence_delay(datas: List[Tuple[AudioItem, List[SegmentModel]]]) -> float:
    """
    Estimate the delay of the ASR tail character vs. the reference end time.

    NOTE(review): this function looks unfinished / debug-only — `delay`,
    `audio_texts` and `asr_texts` are computed but never used, results are
    printed instead of returned, and the function always returns 0.
    Confirm whether it is still meant to be called.
    """
    for audio_item, asr_results in datas:
        if not audio_item.voice:
            continue  # no reference content for this item
        #
        audio_texts = []
        asr_texts = []
        ref = audio_item.voice[0]  # take the first reference segment by default
        ref_end_ms = int(ref.end * 1000)
        # Find all ASR segments whose text contains the reference tail
        # character (simplified: just the last character of the answer).
        target_char = ref.answer.strip()[-1]  # tail character of the reference
        matching_results = [r for r in asr_results if target_char in r.text and r.words]
        if not matching_results:
            continue  # no ASR segment contains the tail character
        # The latest word end_time among all matching segments is taken as
        # the tail-character time; delay is measured against the reference end.
        latest_word_time = max(word.end_time for r in matching_results for word in r.words)
        delay = latest_word_time - ref_end_ms
        print(audio_item)
        print(asr_results)
    return 0
def align_texts(ref_text: str, hyp_text: str) -> List[Tuple[Optional[int], Optional[int]]]:
    """
    Character-level alignment of two strings via Levenshtein edit operations.

    Walks the ops from ``Levenshtein.editops(ref_text, hyp_text)`` and emits
    one pair per aligned position:

    * ``(ref_idx, hyp_idx)`` — characters matched or replaced at these indices
    * ``(ref_idx, None)``    — character only in ``ref_text`` (deletion)
    * ``(None, hyp_idx)``    — character only in ``hyp_text`` (insertion)

    :param ref_text: reference string
    :param hyp_text: hypothesis string
    :return: list of (ref_idx, hyp_idx) pairs covering both strings in order
    """
    ops = Levenshtein.editops(ref_text, hyp_text)
    ref_len = len(ref_text)
    hyp_len = len(hyp_text)
    ref_idx = 0
    hyp_idx = 0
    alignment = []
    for op, i, j in ops:
        # Emit the run of matching characters preceding this edit op.
        while ref_idx < i and hyp_idx < j:
            alignment.append((ref_idx, hyp_idx))
            ref_idx += 1
            hyp_idx += 1
        if op == "replace":
            alignment.append((i, j))
            ref_idx = i + 1
            hyp_idx = j + 1
        elif op == "delete":
            alignment.append((i, None))
            ref_idx = i + 1
        elif op == "insert":
            alignment.append((None, j))
            hyp_idx = j + 1
    # Trailing matched run after the last edit op.
    while ref_idx < ref_len and hyp_idx < hyp_len:
        alignment.append((ref_idx, hyp_idx))
        ref_idx += 1
        hyp_idx += 1
    # Any leftover characters on either side are unaligned.
    while ref_idx < ref_len:
        alignment.append((ref_idx, None))
        ref_idx += 1
    while hyp_idx < hyp_len:
        alignment.append((None, hyp_idx))
        hyp_idx += 1
    return alignment
def align_tokens(ref_text: List[str], hyp_text: List[str]) -> List[Tuple[Optional[int], Optional[int]]]:
    """
    Align two token sequences with difflib.SequenceMatcher.

    Emits (ref_idx, hyp_idx) for matched/replaced tokens, (ref_idx, None)
    for tokens only in the reference, and (None, hyp_idx) for tokens only
    in the hypothesis. For a 'replace' span of unequal lengths only the
    overlapping prefix is paired (zip semantics).
    """
    pairs: List[Tuple[Optional[int], Optional[int]]] = []
    for tag, i1, i2, j1, j2 in SequenceMatcher(None, ref_text, hyp_text).get_opcodes():
        if tag in ('equal', 'replace'):
            pairs.extend(zip(range(i1, i2), range(j1, j2)))
        elif tag == 'delete':
            pairs.extend((r, None) for r in range(i1, i2))
        elif tag == 'insert':
            pairs.extend((None, h) for h in range(j1, j2))
    return pairs
def find_tail_word_time(
    ref_text: List[str],
    pred_text: List[str],
    merged_words: List[WordModel],
    char2word_idx: List[int],
    alignment: List[Tuple[Optional[int], Optional[int]]],
) -> Optional[WordModel]:
    """
    Locate the WordModel that holds the reference tail (last non-punctuation)
    character.

    Scans ref_text backwards past punctuation, maps the tail index through
    *alignment* to a hypothesis character index, then through *char2word_idx*
    to the owning WordModel.

    :param ref_text: reference text as a character list
    :param pred_text: hypothesis text as a character list (kept for signature
                      compatibility; not used here)
    :param merged_words: WordModel list of the merged hypothesis
    :param char2word_idx: WordModel index for each hypothesis character
    :param alignment: (ref_idx, hyp_idx) character alignment pairs
    :return: the tail WordModel, or None when it cannot be resolved
    """
    punctuation = set(",。!?、,.!?;")
    # 1. Last non-punctuation character of the reference.
    ref_tail = len(ref_text) - 1
    while ref_tail >= 0 and ref_text[ref_tail] in punctuation:
        ref_tail -= 1
    if ref_tail < 0:
        return None  # reference is punctuation only
    # 2. Map it to a hypothesis index via the alignment, scanning from the end.
    hyp_tail = next(
        (h for r, h in reversed(alignment) if r == ref_tail and h is not None),
        None,
    )
    # 3./4. Bounds checks on the hypothesis index and the word index.
    if hyp_tail is None or hyp_tail >= len(char2word_idx):
        return None
    word_pos = char2word_idx[hyp_tail]
    # 5. Return the owning WordModel.
    return merged_words[word_pos] if word_pos < len(merged_words) else None
def find_head_word_time(
    ref_text: List[str],
    pred_text: List[str],
    merged_words: List[WordModel],
    char2word_idx: List[int],
    alignment: List[Tuple[Optional[int], Optional[int]]],
) -> Optional[WordModel]:
    """
    Locate the WordModel aligned to the very first reference character.

    Note: unlike find_tail_word_time, the head index is fixed at 0 —
    leading punctuation is NOT skipped.

    :param ref_text: reference text as a character list
    :param pred_text: hypothesis text as a character list (unused; kept for
                      signature compatibility)
    :param merged_words: WordModel list of the merged hypothesis
    :param char2word_idx: WordModel index for each hypothesis character
    :param alignment: (ref_idx, hyp_idx) character alignment pairs
    :return: the matching WordModel, or None when index 0 has no in-bounds
             aligned hypothesis character
    """
    for ref_pos, hyp_pos in alignment:
        if ref_pos == 0 and hyp_pos is not None and 0 <= hyp_pos < len(char2word_idx):
            return merged_words[char2word_idx[hyp_pos]]
    return None
def merge_asr_results(asr_list: List[SegmentModel]) -> Tuple[str, List[WordModel], List[int]]:
    """
    Merge several ASR segment results into one text plus a flat word list,
    dropping punctuation from the merged text while keeping every WordModel.

    Side effect: each word gets ``word.segment`` set to its owning segment.

    :param asr_list: ASR segments; entries without text or words are skipped
    :return: tuple of
        - merged_text: concatenated text with punctuation removed
        - merged_words: flat WordModel list (punctuation words included)
        - char2word_idx: for each character of merged_text, the index of its
          WordModel in merged_words
    """
    # Cleanup: removed the dead commented-out first implementation and the
    # duplicate second docstring; merged the two passes over asr.words into one.
    punct_set = set(",。!?、,.!?;")  # punctuation filtered out of the text
    merged_text = ""
    merged_words: List[WordModel] = []
    char2word_idx: List[int] = []
    for asr in asr_list:
        if not asr.text or not asr.words:
            continue
        # Single pass: register each word, then append its non-punctuation
        # characters while recording which word every character belongs to.
        for word in asr.words:
            word.segment = asr
            merged_words.append(word)
            word_idx = len(merged_words) - 1
            for ch in word.text:
                if ch not in punct_set:
                    merged_text += ch
                    char2word_idx.append(word_idx)
    return merged_text, merged_words, char2word_idx
def rebuild_char2word_idx(pred_tokens: List[str], merged_words: List[WordModel]) -> List[int]:
    """
    Rebuild char2word_idx so it lines up one-to-one with pred_tokens.

    Walks every character of every word in order; once len(pred_tokens)
    characters have been consumed, remaining characters are ignored.
    """
    mapping: List[int] = []
    consumed = 0
    limit = len(pred_tokens)
    for widx, word in enumerate(merged_words):
        for _ in word.text:
            if consumed < limit:
                mapping.append(widx)
            consumed += 1
    return mapping
def build_hyp_token_to_asr_chart_index(
    hyp_tokens: List[str],
    asr_words: List[WordModel]
) -> List[int]:
    """
    Map each hyp_token index to the index of the ASR word it belongs to.

    Assumes each ASR word's text is the concatenation of a run of consecutive
    hyp_tokens (simple greedy matching). On a mismatch one hyp token is
    skipped (its mapping stays -1) and matching is retried.
    """
    mapping = [-1] * len(hyp_tokens)
    word_pos = 0
    token_pos = 0
    while word_pos < len(asr_words) and token_pos < len(hyp_tokens):
        word_text = asr_words[word_pos].text
        span = len(word_text)
        # Candidate: the next `span` hyp tokens joined together.
        candidate = "".join(hyp_tokens[token_pos:token_pos + span])
        if candidate == word_text:
            for k in range(token_pos, token_pos + span):
                mapping[k] = word_pos
            token_pos += span
            word_pos += 1
        else:
            # Error tolerance: skip one hyp token and try again.
            token_pos += 1
    return mapping
import re
def normalize(text: str) -> str:
    """Lowercase *text* and drop all non-word characters, keeping apostrophes."""
    lowered = text.lower()
    return re.sub(r"[^\w']+", '', lowered)
def build_hyp_token_to_asr_word_index(hyp_tokens: List[str], asr_words: List[WordModel]) -> List[int]:
hyp_to_asr_word_idx = [-1] * len(hyp_tokens)
i_hyp, i_asr = 0, 0
while i_hyp < len(hyp_tokens) and i_asr < len(asr_words):
hyp_token = normalize(hyp_tokens[i_hyp])
asr_word = normalize(asr_words[i_asr].text)
# 匹配包含/前缀关系,提高鲁棒性
if hyp_token == asr_word or hyp_token in asr_word or asr_word in hyp_token:
hyp_to_asr_word_idx[i_hyp] = i_asr
i_hyp += 1
i_asr += 1
else:
i_hyp += 1
return hyp_to_asr_word_idx
def find_tail_word(
    ref_tokens: List[str],  # reference token list
    hyp_tokens: List[str],  # hypothesis token list
    alignment: List[Tuple[Optional[int], Optional[int]]],  # (ref_idx, hyp_idx) pairs
    hyp_to_asr_word_idx: dict,
    asr_words: List[WordModel],
    punct_set: set = set(",。!?、,.!?;")
) -> Optional[WordModel]:
    """
    Resolve the WordModel for the reference tail token.

    Strips trailing punctuation from the reference, maps the tail token
    through *alignment* to a hypothesis index, then through
    *hyp_to_asr_word_idx* to an ASR word. When the tail token itself has
    no alignment, falls back to the last hypothesis index aligned to any
    reference token.
    """
    # 1. Reference tail = last non-punctuation token.
    tail = len(ref_tokens) - 1
    while tail >= 0 and ref_tokens[tail] in punct_set:
        tail -= 1
    if tail < 0:
        return None
    # 2. Preferred: the hyp index aligned to the tail token (scan from the end).
    hyp_idx = next((h for r, h in reversed(alignment) if r == tail and h is not None), None)
    # 3. Fallback: the last aligned hyp index of any reference token.
    if hyp_idx is None:
        hyp_idx = next((h for _, h in reversed(alignment) if h is not None), None)
    if hyp_idx is None or hyp_idx >= len(hyp_to_asr_word_idx):
        return None
    # 4. Map to the ASR word index and bounds-check it.
    word_idx = hyp_to_asr_word_idx[hyp_idx]
    if word_idx is None or word_idx < 0 or word_idx >= len(asr_words):
        return None
    return asr_words[word_idx]
def find_tail_word2(
    ref_tokens: List[str],  # reference token list
    hyp_tokens: List[str],  # hypothesis token list
    alignment: List[Tuple[Optional[int], Optional[int]]],  # (ref_idx, hyp_idx) alignment
    hyp_to_asr_word_idx: List[int],  # ASR word index per hyp token
    asr_words: List[WordModel],
    punct_set: set = set(",。!?、,.!?;"),
    enable_debug: bool = False
) -> Optional[WordModel]:
    """
    Variant of find_tail_word: resolve the WordModel of the last effectively
    aligned reference token. When the tail token has no alignment, walks
    backwards over every non-punctuation reference token until one has an
    aligned hypothesis index. Returns None when nothing can be resolved.
    """
    # Step 1: last non-punctuation reference token.
    tail_ref = len(ref_tokens) - 1
    while tail_ref >= 0 and ref_tokens[tail_ref] in punct_set:
        tail_ref -= 1
    if tail_ref < 0:
        if enable_debug:
            print("全是标点,尾字找不到")
        return None
    # Step 2: hyp index aligned with the tail token, scanning from the end.
    tail_hyp = next((h for r, h in reversed(alignment) if r == tail_ref and h is not None), None)
    # Step 3: fallback — walk backwards to the nearest alignable non-punct token.
    probe = tail_ref
    while tail_hyp is None and probe >= 0:
        if ref_tokens[probe] not in punct_set:
            tail_hyp = next((h for r, h in reversed(alignment) if r == probe and h is not None), None)
        probe -= 1
    if tail_hyp is None or tail_hyp >= len(hyp_to_asr_word_idx):
        if enable_debug:
            print(f"tail_hyp_idx 无法找到或超出范围: {tail_hyp}")
        return None
    asr_idx = hyp_to_asr_word_idx[tail_hyp]
    if asr_idx is None or asr_idx < 0 or asr_idx >= len(asr_words):
        if enable_debug:
            print(f"asr_word_idx 无效: {asr_idx}")
        return None
    return asr_words[asr_idx]
def find_head_word(
    ref_tokens: List[str],
    hyp_tokens: List[str],
    alignment: List[Tuple[Optional[int], Optional[int]]],
    hyp_to_asr_word_idx: dict,
    asr_words: List[WordModel],
    punct_set: set = set(",。!?、,.!?;")
) -> Optional[WordModel]:
    """
    Resolve the WordModel for the first non-punctuation reference token:
    skip leading punctuation, map that token through the alignment to a
    hypothesis index, then through hyp_to_asr_word_idx to the ASR word.
    """
    # 1. First non-punctuation reference token.
    head = 0
    while head < len(ref_tokens) and ref_tokens[head] in punct_set:
        head += 1
    if head >= len(ref_tokens):
        return None
    # 2. First alignment pair for that token.
    hyp_idx = next((h for r, h in alignment if r == head and h is not None), None)
    if hyp_idx is None or hyp_idx >= len(hyp_to_asr_word_idx):
        return None
    # 3. Map to the ASR word index and bounds-check it.
    word_idx = hyp_to_asr_word_idx[hyp_idx]
    if word_idx is None or word_idx < 0 or word_idx >= len(asr_words):
        return None
    return asr_words[word_idx]
def calculate_sentence_delay(
    datas: List[Tuple[AudioItem, List[SegmentModel]]], language: str = "zh"
) -> Tuple[float, float, float, float, float, float]:
    """
    Compute delay statistics between reference sentences and ASR output.

    For every (reference, ASR result) pair, align the reference tokens with
    the flattened ASR word tokens, locate the first and last reference tokens
    in the ASR stream, and accumulate four offsets:

    * head offset     — |head-word start_time - reference start|
    * tail offset     — |reference end - tail-word end_time|
    * standard offset — tail word's segment receive_time - reference end
                        (clamped to 0; clamps counted in standard_fix)
    * final offset    — per-segment receive_time - end_time, averaged per item
                        (clamped to 0; clamps counted in final_fix)

    :param datas: reference items paired with their ASR segment lists
    :param language: language code; selects char- vs word-level alignment
    :return: (head_not_found_ratio, average_head_offset, tail_not_found_ratio,
              average_standard_offset, average_final_offset,
              average_tail_offset) — averages in seconds. All ratios/averages
              are 0.0 when their denominator would be zero (the original
              raised ZeroDivisionError; the annotation also wrongly said a
              3-tuple was returned).
    """
    tail_offset_time = 0      # summed tail offsets (ms)
    standard_offset_time = 0  # summed standard-sentence offsets (ms)
    tail_not_found = 0        # samples whose tail word was not found
    tail_found = 0            # samples whose tail word was found
    standard_fix = 0          # negative standard offsets clamped to 0
    final_fix = 0             # negative final offsets clamped to 0
    head_offset_time = 0      # summed head offsets (ms)
    final_offset_time = 0     # summed per-item mean final offsets (ms)
    head_not_found = 0        # samples whose head word was not found
    head_found = 0            # samples whose head word was found
    for audio_item, asr_list in datas:
        # Guard added: an empty asr_list would divide by zero below.
        if not audio_item.voice or not asr_list:
            continue
        # Join all reference segments; the last segment's end time is the
        # sentence end, the first segment's start time the sentence start.
        ref_text = "".join(voice.answer.strip() for voice in audio_item.voice)
        if not ref_text:
            continue
        logger.debug(f"-=-=-=-=-=-=-=-=-=-=-=-=-=start-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=--=")
        ref_end_ms = int(audio_item.voice[-1].end * 1000)
        ref_start_ms = int(audio_item.voice[0].start * 1000)
        # Flatten ASR words; accumulate the per-segment finalization offset
        # (receive_time - end_time), clamping negatives to 0.
        pred_text = []
        asr_words: List[WordModel] = []
        temp_final_offset_time = 0
        for asr in asr_list:
            pred_text = pred_text + [word.text for word in asr.words]
            final_offset = asr.receive_time - asr.end_time
            logger.debug(f"asr.receive_time {asr.receive_time} , asr.end_time {asr.end_time} , final_offset {final_offset}")
            asr_words = asr_words + asr.words
            for word in asr.words:
                word.segment = asr
            if final_offset < 0:
                final_fix = final_fix + 1
                final_offset = 0
            temp_final_offset_time = temp_final_offset_time + final_offset
        final_offset_time = final_offset_time + temp_final_offset_time / len(asr_list)
        logger.debug(f"text: {ref_text},pred_text: {pred_text}")
        # Flatten the tokenized prediction and remember which ASR word each
        # flat token came from.
        flat_pred_tokens = []
        hyp_to_asr_word_idx = {}  # flat token index -> asr_words index
        alignment_type = get_alignment_type(language)
        if alignment_type == "chart":
            label_tokens = Tokenizer.tokenize([ref_text], language)[0]
            pred_tokens = Tokenizer.tokenize(pred_text, language)
            for asr_idx, token_group in enumerate(pred_tokens):
                for token in token_group:
                    flat_pred_tokens.append(token)
                    hyp_to_asr_word_idx[len(flat_pred_tokens) - 1] = asr_idx
            alignment = align_texts(label_tokens, "".join(flat_pred_tokens))
        else:
            label_tokens = Tokenizer.norm_and_tokenize([ref_text], language)[0]
            pred_tokens = Tokenizer.norm_and_tokenize(pred_text, language)
            for asr_idx, token_group in enumerate(pred_tokens):
                for token in token_group:
                    flat_pred_tokens.append(token)
                    hyp_to_asr_word_idx[len(flat_pred_tokens) - 1] = asr_idx
            alignment = align_tokens(label_tokens, flat_pred_tokens)
        logger.debug(f"ref_tokens: {label_tokens}")
        logger.debug(f"pred_tokens: {pred_tokens}")
        logger.debug(f"alignment sample: {alignment[:30]}")  # only the first 30, keeps logs small
        logger.debug(f"hyp_to_asr_word_idx: {hyp_to_asr_word_idx}")
        head_word_info = find_head_word(label_tokens, pred_tokens, alignment, hyp_to_asr_word_idx, asr_words)
        if head_word_info is None:
            head_not_found = head_not_found + 1
            logger.debug(f"未找到首字")
        else:
            logger.debug(f"head_word: {head_word_info.text} ref_start_ms:{ref_start_ms}")
            # Head offset = |head word start_time - reference start|.
            head_offset_time = head_offset_time + abs(head_word_info.start_time - ref_start_ms)
            head_found += 1
        # Resolve the ASR word that carries the reference tail token.
        tail_word_info = find_tail_word(label_tokens, pred_tokens, alignment, hyp_to_asr_word_idx, asr_words)
        if tail_word_info is None:
            tail_not_found = tail_not_found + 1
            logger.debug(f"未找到尾字")
        else:
            logger.debug(f"tail_word: {tail_word_info.text} ref_end_ms: {ref_end_ms}")
            # Tail offset = |reference end - tail word end_time|.
            tail_offset_time = abs(ref_end_ms - tail_word_info.end_time) + tail_offset_time
            # Standard offset = tail word's segment receive time - reference end.
            standard_offset = tail_word_info.segment.receive_time - ref_end_ms
            logger.debug(f"tail_word_info.segment.receive_time {tail_word_info.segment.receive_time } , tail_word_info.end_time {tail_word_info.end_time} , ref_end_ms {ref_end_ms}")
            if standard_offset < 0:
                standard_offset = 0
                standard_fix = standard_fix + 1
            standard_offset_time = standard_offset + standard_offset_time
            tail_found += 1
        logger.info(f"-=-=-=-=-=-=-=-=-=-=-=-=-=end-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=--=")
    logger.debug(
        f"找到首字字数量: {head_found},未找到首字数量:{head_not_found},找到尾字数量: {tail_found},未找到尾字数量:{tail_not_found},修正标答偏移负数数量:{standard_fix},修正定稿偏移负数数量:{final_fix}")
    logger.debug(
        f"尾字偏移总时间:{tail_offset_time},标答句偏移总时间:{standard_offset_time},首字偏移总时间:{head_offset_time},定稿偏移总时间:{final_offset_time}")
    # Averages (seconds) and miss ratios; every division is guarded so an
    # empty dataset or a run with no hits returns 0.0 instead of crashing.
    head_total = head_found + head_not_found
    tail_total = tail_found + tail_not_found
    head_not_found_ratio = head_not_found / head_total if head_total else 0.0
    tail_not_found_ratio = tail_not_found / tail_total if tail_total else 0.0
    average_tail_offset = tail_offset_time / tail_found / 1000 if tail_found else 0.0
    average_head_offset = head_offset_time / head_found / 1000 if head_found else 0.0
    average_standard_offset = standard_offset_time / tail_found / 1000 if tail_found else 0.0
    average_final_offset = final_offset_time / tail_found / 1000 if tail_found else 0.0
    logger.info(
        f"首字未找到比例:{head_not_found_ratio},尾字未找到比例:{tail_not_found_ratio},首字偏移时间:{average_head_offset},尾字偏移时间:{average_tail_offset},标答句偏移时间:{average_standard_offset},定稿偏移时间:{average_final_offset}")
    return head_not_found_ratio, average_head_offset, tail_not_found_ratio, average_standard_offset, average_final_offset, average_tail_offset
if __name__ == '__main__':
    # Ad-hoc smoke test: compare the alignment helpers on a few zh/en
    # reference-hypothesis pairs and time the tokenizer.
    checks = [
        {"type": "zh", "ref": "今天天气真好", "hyp": "今天真好"},
        {"type": "zh", "ref": "我喜欢吃苹果", "hyp": "我很喜欢吃香蕉"},
        {"type": "zh", "ref": "我喜欢吃苹果", "hyp": "我喜欢吃苹果"},
        {"type": "en", "ref": "I like to eat apples", "hyp": "I really like eating apples"},
        {"type": "en", "ref": "She is going to the market", "hyp": "She went market"},
        {"type": "en", "ref": "Hello world", "hyp": "Hello world"},
        {"type": "en", "ref": "Good morning", "hyp": "Bad night"},
    ]
    for check in checks:
        ref = check.get("ref")
        lang = check.get("type")  # renamed from `type`, which shadowed the builtin
        hyp = check.get("hyp")
        res1 = align_texts(ref, hyp)
        res2 = align_tokens(list(ref), list(hyp))
        # Tokenizer is already imported at module top; the former in-loop
        # re-import was redundant and has been removed.
        start = time.time()
        tokens_pred = Tokenizer.norm_and_tokenize([ref], lang)
        print(time.time() - start)
        start = time.time()
        Tokenizer.norm_and_tokenize([ref + ref + ref], lang)
        print(time.time() - start)
        tokens_label = Tokenizer.norm_and_tokenize([hyp], lang)
        print(tokens_pred)
        print(tokens_label)
        res3 = align_tokens(tokens_pred[0], tokens_label[0])
        print(res1 == res2)
        print(res1 == res3)