# enginex-c_series-asr/utils/calculate.py
# (hosting-page metadata removed: 835 lines, 29 KiB, Python; snapshot 2025-08-28 18:46:56 +08:00)
import re
import time
import Levenshtein
from utils.tokenizer import Tokenizer
from typing import List, Tuple
from utils.model import SegmentModel
from utils.model import AudioItem
from utils.reader import read_data
from utils.logger import logger
from utils.model import VoiceSegment
from utils.model import WordModel
from difflib import SequenceMatcher
import logging
def calculate_punctuation_ratio(datas: List[Tuple[AudioItem, List[SegmentModel]]]) -> float:
    """
    Compute the ratio of punctuation marks in the ASR output relative to the
    reference transcripts, aggregated over every audio item.

    :param datas: pairs of (reference AudioItem, list of ASR segment results)
    :return: total generated punctuation count / total reference punctuation
             count; 0.0 when the references contain no punctuation at all
             (the ratio is undefined in that case — guards ZeroDivisionError)
    """
    total_standard_punctuation = 0
    total_gen_punctuation = 0
    for answer, results in datas:
        # Concatenate all reference segments and all ASR segments, then count
        # punctuation on the full strings (join avoids quadratic concatenation).
        standard_text = "".join(item.answer for item in answer.voice)
        gen_text = "".join(item.text for item in results)
        total_standard_punctuation += count_punctuation(standard_text)
        total_gen_punctuation += count_punctuation(gen_text)
    if total_standard_punctuation == 0:
        # No reference punctuation: ratio undefined, report 0.0 instead of crashing.
        return 0.0
    return total_gen_punctuation / total_standard_punctuation
def calculate_acc(datas: List[Tuple[AudioItem, List[SegmentModel]]], language: str) -> float:
    """
    Average per-item accuracy (1 - CER/WER) over the dataset.

    :param datas: pairs of (reference AudioItem, list of ASR segment results)
    :param language: language code, forwarded to cal_per_cer for tokenization
    :return: mean accuracy across items; 0.0 for an empty dataset
             (guards ZeroDivisionError)
    """
    if not datas:
        return 0.0
    total_acc = 0.0
    for answer, results in datas:
        # Concatenate all reference segments / all ASR segments, then score
        # the full strings against each other.
        standard_text = "".join(item.answer for item in answer.voice)
        gen_text = "".join(item.text for item in results)
        total_acc += cal_per_cer(gen_text, standard_text, language)
    return total_acc / len(datas)
def get_alignment_type(language: str) -> str:
    """
    Decide how texts of *language* should be aligned.

    :param language: ISO language code
    :return: "chart" for character-aligned scripts, otherwise "word"
    """
    # Chinese, Japanese, Korean, Thai, Lao, Burmese, Khmer, Tibetan
    char_aligned = {"zh", "ja", "ko", "th", "lo", "my", "km", "bo"}
    return "chart" if language in char_aligned else "word"
def cal_per_cer(text: str, answer: str, language: str) -> float:
    """
    Per-sample accuracy (1 - CER/WER) between a hypothesis and a reference.

    :param text: hypothesis text
    :param answer: reference text
    :param language: language code; "chart" languages are scored per character
                     via Levenshtein, others per token via SequenceMatcher
    :return: accuracy (1 - error rate); may be negative when insertions exceed
             the reference length
    """
    if not answer:
        # Empty reference: perfect only when the hypothesis is empty too.
        # (Bug fix: this branch used to return the CER value — 1.0 meaning
        # all-wrong — although every other path returns accuracy.)
        return 0.0 if text else 1.0
    text = remove_punctuation(text)
    answer = remove_punctuation(answer)
    text_chars = Tokenizer.norm_and_tokenize([text], language)[0]
    answer_chars = Tokenizer.norm_and_tokenize([answer], language)[0]
    if not answer_chars:
        # Reference reduced to nothing after normalization.
        return 0.0
    alignment_type = get_alignment_type(language)
    if alignment_type == "chart":
        # Character-level scoring: the tokenization above is intentionally
        # discarded and plain character lists are used instead.
        text_chars = list(text)
        answer_chars = list(answer)
        ops = Levenshtein.editops(text_chars, answer_chars)
        insert = sum(1 for op in ops if op[0] == "insert")
        delete = sum(1 for op in ops if op[0] == "delete")
        replace = sum(1 for op in ops if op[0] == "replace")
    else:
        matcher = SequenceMatcher(None, text_chars, answer_chars)
        insert = 0
        delete = 0
        replace = 0
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag == 'replace':
                # A replace span may cover unequal lengths; count the larger side.
                replace += max(i2 - i1, j2 - j1)
            elif tag == 'delete':
                delete += (i2 - i1)
            elif tag == 'insert':
                insert += (j2 - j1)
    cer = (insert + delete + replace) / len(answer_chars)
    return 1 - cer
def cal_total_cer(samples: list) -> float:
    """
    Corpus-level accuracy (1 - CER) over a list of (pred_text, ref_text) pairs.

    Edit operations are accumulated over all samples and normalized by the
    total reference length, i.e. longer references weigh more.

    :param samples: list of (hypothesis, reference) string pairs
    :return: 1 - corpus CER; 1.0 for a fully matching (or empty) corpus,
             0.0 when there are errors but no reference characters at all
    """
    total_insert = 0
    total_delete = 0
    total_replace = 0
    total_ref_len = 0
    for text, answer in samples:
        # Bug fix: an empty reference used to `return` from inside the loop,
        # discarding every previously accumulated sample and returning a
        # CER-scale value from an accuracy function. Levenshtein handles the
        # empty reference naturally (all hypothesis chars become deletions).
        text_chars = list(remove_punctuation(text or ""))
        answer_chars = list(remove_punctuation(answer or ""))
        ops = Levenshtein.editops(text_chars, answer_chars)
        total_insert += sum(1 for op in ops if op[0] == "insert")
        total_delete += sum(1 for op in ops if op[0] == "delete")
        total_replace += sum(1 for op in ops if op[0] == "replace")
        total_ref_len += len(answer_chars)
    errors = total_insert + total_delete + total_replace
    if total_ref_len == 0:
        # No reference characters anywhere: perfect only if no errors either.
        return 1.0 if errors == 0 else 0.0
    return 1 - errors / total_ref_len
def remove_punctuation(text: str) -> str:
    """Strip punctuation: drop every character that is not a word character,
    whitespace, or a CJK ideograph (U+4E00–U+9FFF)."""
    pattern = re.compile(r'[^\w\s\u4e00-\u9fff]')
    return pattern.sub('', text)
def count_punctuation(text: str) -> int:
    """Count punctuation-like characters: anything that is not a word
    character, whitespace, or a CJK ideograph in U+4E00–U+9FA5.

    NOTE(review): this CJK range (\\u9fa5) is narrower than the one used by
    remove_punctuation (\\u9fff) — confirm whether that is intentional.
    """
    return sum(1 for _ in re.finditer(r"[^\w\s\u4e00-\u9fa5]", text))
from typing import List, Optional, Tuple
def calculate_standard_sentence_delay(datas: List[Tuple[AudioItem, List[SegmentModel]]]) -> float:
    """
    (Debug / work-in-progress helper.) Inspects, per audio item, how late the
    ASR output contains the tail character of the first reference segment.

    NOTE(review): the computed `delay` is only printed implicitly via the
    debug prints below and the function always returns 0 — it produces no
    usable metric yet. The unused `audio_texts`/`asr_texts` lists also
    suggest unfinished code.

    :param datas: pairs of (reference AudioItem, list of ASR segment results)
    :return: always 0
    """
    for audio_item, asr_results in datas:
        if not audio_item.voice:
            continue  # no reference content for this item
        #
        audio_texts = []
        asr_texts = []
        ref = audio_item.voice[0]  # simplification: use the first reference segment only
        ref_end_ms = int(ref.end * 1000)
        # Simplification: locate ASR segments containing the reference's last
        # character. NOTE(review): raises IndexError when ref.answer is
        # empty/whitespace — confirm upstream guarantees non-empty answers.
        target_char = ref.answer.strip()[-1]
        matching_results = [r for r in asr_results if target_char in r.text and r.words]
        if not matching_results:
            continue  # no ASR segment contains the tail character
        # Latest word end time among the matching segments ≈ tail-char time.
        latest_word_time = max(word.end_time for r in matching_results for word in r.words)
        delay = latest_word_time - ref_end_ms
        print(audio_item)
        print(asr_results)
    return 0
def align_texts(ref_text: str, hyp_text: str) -> List[Tuple[Optional[int], Optional[int]]]:
    """
    Character-level alignment of two strings via Levenshtein edit operations.

    :param ref_text: reference string
    :param hyp_text: hypothesis string
    :return: list of (ref_idx, hyp_idx) pairs covering every character;
             an index is None on the side that has no counterpart
    """
    edit_ops = Levenshtein.editops(ref_text, hyp_text)
    total_ref = len(ref_text)
    total_hyp = len(hyp_text)
    r = 0
    h = 0
    pairs: List[Tuple[Optional[int], Optional[int]]] = []
    for op, i, j in edit_ops:
        # Characters before the next edit op are matched one-to-one.
        while r < i and h < j:
            pairs.append((r, h))
            r += 1
            h += 1
        if op == "replace":
            pairs.append((i, j))
            r, h = i + 1, j + 1
        elif op == "delete":
            pairs.append((i, None))
            r = i + 1
        elif op == "insert":
            pairs.append((None, j))
            h = j + 1
    # Trailing segment after the last edit op: matched, then leftovers.
    while r < total_ref and h < total_hyp:
        pairs.append((r, h))
        r += 1
        h += 1
    while r < total_ref:
        pairs.append((r, None))
        r += 1
    while h < total_hyp:
        pairs.append((None, h))
        h += 1
    return pairs
def align_tokens(ref_text: List[str], hyp_text: List[str]) -> List[Tuple[Optional[int], Optional[int]]]:
    """
    Token-level alignment of two token sequences via difflib.SequenceMatcher.

    :param ref_text: reference tokens
    :param hyp_text: hypothesis tokens
    :return: list of (ref_idx, hyp_idx) pairs; an index is None on the side
             that has no counterpart (insertion/deletion)
    """
    matcher = SequenceMatcher(None, ref_text, hyp_text)
    alignment: List[Tuple[Optional[int], Optional[int]]] = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == 'equal' or tag == 'replace':
            # Pair tokens positionally. Bug fix: for 'replace' spans of
            # unequal length the leftover tokens used to be silently dropped
            # by zip; record them as unmatched so every index appears, the
            # same contract align_texts provides.
            for r, h in zip(range(i1, i2), range(j1, j2)):
                alignment.append((r, h))
            ref_span = i2 - i1
            hyp_span = j2 - j1
            if ref_span > hyp_span:
                for r in range(i1 + hyp_span, i2):
                    alignment.append((r, None))
            elif hyp_span > ref_span:
                for h in range(j1 + ref_span, j2):
                    alignment.append((None, h))
        elif tag == 'delete':
            for r in range(i1, i2):
                alignment.append((r, None))
        elif tag == 'insert':
            for h in range(j1, j2):
                alignment.append((None, h))
    return alignment
def find_tail_word_time(
    ref_text: List[str],
    pred_text: List[str],
    merged_words: List[WordModel],
    char2word_idx: List[int],
    alignment: List[Tuple[Optional[int], Optional[int]]],
) -> Optional[WordModel]:
    """
    Locate the WordModel that carries the reference tail character (the last
    non-punctuation character) inside the merged hypothesis.

    :param ref_text: reference text as a character list
    :param pred_text: hypothesis characters (kept for signature parity)
    :param merged_words: WordModels backing the hypothesis text
    :param char2word_idx: hypothesis char index -> merged_words index
    :param alignment: (ref_idx, hyp_idx) pairs from align_texts/align_tokens
    :return: the matching WordModel, or None when no tail mapping exists
    """
    punctuation = set(",。!?、,.!?;")
    # Walk backwards past any trailing punctuation in the reference.
    last_ref = len(ref_text) - 1
    while last_ref >= 0 and ref_text[last_ref] in punctuation:
        last_ref -= 1
    if last_ref < 0:
        # Reference is punctuation-only: no tail character exists.
        return None
    # Latest alignment entry mapping the tail reference char to a hyp char.
    matched_hyp = next(
        (h for r, h in reversed(alignment) if r == last_ref and h is not None),
        None,
    )
    if matched_hyp is None or matched_hyp >= len(char2word_idx):
        return None
    word_pos = char2word_idx[matched_hyp]
    return merged_words[word_pos] if word_pos < len(merged_words) else None
def find_head_word_time(
    ref_text: List[str],
    pred_text: List[str],
    merged_words: List[WordModel],
    char2word_idx: List[int],
    alignment: List[Tuple[Optional[int], Optional[int]]],
) -> Optional[WordModel]:
    """
    Locate the WordModel that carries the reference head character (index 0)
    inside the merged hypothesis.

    :param ref_text: reference text as a character list (kept for parity)
    :param pred_text: hypothesis characters (kept for parity)
    :param merged_words: WordModels backing the hypothesis text
    :param char2word_idx: hypothesis char index -> merged_words index
    :param alignment: (ref_idx, hyp_idx) pairs from align_texts/align_tokens
    :return: the WordModel containing the head character, or None
    """
    for ref_pos, hyp_pos in alignment:
        # The head is fixed at reference index 0; skip everything else and
        # any entry whose hypothesis side is unmatched.
        if ref_pos != 0 or hyp_pos is None:
            continue
        if 0 <= hyp_pos < len(char2word_idx):
            return merged_words[char2word_idx[hyp_pos]]
    return None
def merge_asr_results(asr_list: List[SegmentModel]) -> Tuple[str, List[WordModel], List[int]]:
    """
    Merge multiple ASR segments into one text plus a flat word list,
    filtering punctuation out of the text while keeping it in the word list.

    Side effect: every word gets a `segment` back-link to its parent segment.

    :param asr_list: ASR segment results (segments with empty text or no
                     words are skipped)
    :return: a 3-tuple of
             - merged_text: concatenated text with punctuation removed
             - merged_words: all WordModels, punctuation words included
             - char2word_idx: for each char of merged_text, the index of the
               WordModel it came from in merged_words
    """
    punct_set = set(",。!?、,.!?;")  # punctuation excluded from merged_text
    text_chars: List[str] = []
    merged_words: List[WordModel] = []
    char2word_idx: List[int] = []
    for asr in asr_list:
        if not asr.text or not asr.words:
            continue
        for word in asr.words:
            word.segment = asr  # back-link for later receive_time lookups
            merged_words.append(word)
            word_idx = len(merged_words) - 1
            for ch in word.text:
                if ch not in punct_set:
                    text_chars.append(ch)
                    char2word_idx.append(word_idx)
    return "".join(text_chars), merged_words, char2word_idx
def rebuild_char2word_idx(pred_tokens: List[str], merged_words: List[WordModel]) -> List[int]:
    """
    Rebuild the char-to-word index so it has exactly one entry per element of
    *pred_tokens* (capped at len(pred_tokens)).

    :param pred_tokens: flat hypothesis tokens
    :param merged_words: WordModels whose concatenated text the tokens cover
    :return: for each pred token position, the owning word's index
    """
    limit = len(pred_tokens)
    mapping: List[int] = []
    for word_idx, word in enumerate(merged_words):
        for _ in word.text:
            if len(mapping) >= limit:
                # Already produced one entry per pred token — stop early.
                return mapping
            mapping.append(word_idx)
    return mapping
def build_hyp_token_to_asr_chart_index(
    hyp_tokens: List[str],
    asr_words: List[WordModel]
) -> List[int]:
    """
    Map each hyp token index to the index of the ASR word whose text it
    belongs to, assuming the ASR word texts appear as contiguous runs of
    hyp tokens. Unmatched tokens map to -1.

    :param hyp_tokens: flat hypothesis tokens (typically single characters)
    :param asr_words: ASR WordModels in order
    :return: list parallel to hyp_tokens with ASR word indices (-1 = no match)
    """
    mapping = [-1] * len(hyp_tokens)
    word_pos = 0
    token_pos = 0
    while word_pos < len(asr_words) and token_pos < len(hyp_tokens):
        word_text = asr_words[word_pos].text
        span = len(word_text)
        # Does the word's text equal the next `span` tokens joined together?
        candidate = "".join(hyp_tokens[token_pos:token_pos + span])
        if candidate == word_text:
            for k in range(token_pos, token_pos + span):
                mapping[k] = word_pos
            token_pos += span
            word_pos += 1
        else:
            # Crude fault tolerance: skip one hyp token and retry the word.
            token_pos += 1
    return mapping
import re
def normalize(text: str) -> str:
    """Lower-case *text* and strip every character that is not a word
    character or an apostrophe (e.g. keeps "don't" intact)."""
    lowered = text.lower()
    return re.sub(r"[^\w']+", '', lowered)
def build_hyp_token_to_asr_word_index(hyp_tokens: List[str], asr_words: List[WordModel]) -> List[int]:
    """
    Greedily map each hyp token index to an ASR word index by normalized
    comparison; unmatched tokens map to -1. Substring containment in either
    direction counts as a match, for robustness against tokenization drift.

    :param hyp_tokens: hypothesis tokens
    :param asr_words: ASR WordModels in order
    :return: list parallel to hyp_tokens with ASR word indices (-1 = no match)
    """
    def _clean(s: str) -> str:
        # Same normalization as the module-level `normalize`: lower-case and
        # keep only word characters and apostrophes.
        return re.sub(r"[^\w']+", '', s.lower())

    mapping = [-1] * len(hyp_tokens)
    token_pos, word_pos = 0, 0
    while token_pos < len(hyp_tokens) and word_pos < len(asr_words):
        token = _clean(hyp_tokens[token_pos])
        word = _clean(asr_words[word_pos].text)
        if token == word or token in word or word in token:
            mapping[token_pos] = word_pos
            word_pos += 1  # consume the ASR word only on a match
        token_pos += 1
    return mapping
def find_tail_word(
    ref_tokens: List[str],  # reference token list
    hyp_tokens: List[str],  # hypothesis token list (kept for signature parity)
    alignment: List[Tuple[Optional[int], Optional[int]]],  # (ref_idx, hyp_idx) pairs
    hyp_to_asr_word_idx: dict,
    asr_words: List[WordModel],
    punct_set: frozenset = frozenset(",。!?、,.!?;")
) -> Optional[WordModel]:
    """
    Locate the ASR WordModel aligned with the reference tail token.

    Walks backwards past reference punctuation, maps the tail token through
    the alignment to a hypothesis index, then through hyp_to_asr_word_idx to
    an ASR word. Falls back to the last hypothesis token that aligned to
    anything when the tail itself has no alignment.

    Note: the default punct_set is a frozenset (fix for the mutable-default-
    argument antipattern); membership semantics are unchanged.

    :return: the matching WordModel, or None when nothing maps
    """
    # 1. Last non-punctuation reference token.
    tail_ref_idx = len(ref_tokens) - 1
    while tail_ref_idx >= 0 and ref_tokens[tail_ref_idx] in punct_set:
        tail_ref_idx -= 1
    if tail_ref_idx < 0:
        return None
    # 2. Its hypothesis counterpart in the alignment (scan from the end).
    tail_hyp_idx = None
    for ref_idx, hyp_idx in reversed(alignment):
        if ref_idx == tail_ref_idx and hyp_idx is not None:
            tail_hyp_idx = hyp_idx
            break
    # 3. Fallback: the last hypothesis token that aligned to any reference token.
    if tail_hyp_idx is None:
        for ref_idx, hyp_idx in reversed(alignment):
            if hyp_idx is not None:
                tail_hyp_idx = hyp_idx
                break
    if tail_hyp_idx is None or tail_hyp_idx >= len(hyp_to_asr_word_idx):
        return None
    # 4. Map the hypothesis token to its ASR word.
    asr_word_idx = hyp_to_asr_word_idx[tail_hyp_idx]
    if asr_word_idx is None or asr_word_idx < 0 or asr_word_idx >= len(asr_words):
        return None
    return asr_words[asr_word_idx]
def find_tail_word2(
    ref_tokens: List[str],  # reference token list
    hyp_tokens: List[str],  # hypothesis token list (kept for signature parity)
    alignment: List[Tuple[Optional[int], Optional[int]]],  # (ref_idx, hyp_idx) pairs
    hyp_to_asr_word_idx: List[int],  # hyp token index -> ASR word index
    asr_words: List[WordModel],
    punct_set: frozenset = frozenset(",。!?、,.!?;"),
    enable_debug: bool = False
) -> Optional[WordModel]:
    """
    Locate the ASR WordModel aligned with the reference tail token, with a
    fallback that walks backwards over earlier non-punctuation reference
    tokens until one of them has an aligned hypothesis token.

    Note: the default punct_set is a frozenset (fix for the mutable-default-
    argument antipattern); membership semantics are unchanged.

    :return: the matching WordModel, or None when nothing maps
    """
    # Step 1. Last non-punctuation index in ref_tokens.
    tail_ref_idx = len(ref_tokens) - 1
    while tail_ref_idx >= 0 and ref_tokens[tail_ref_idx] in punct_set:
        tail_ref_idx -= 1
    if tail_ref_idx < 0:
        if enable_debug:
            print("全是标点,尾字找不到")
        return None
    # Step 2. Look up tail_ref_idx's hypothesis counterpart in the alignment.
    tail_hyp_idx = None
    for ref_idx, hyp_idx in reversed(alignment):
        if ref_idx == tail_ref_idx and hyp_idx is not None:
            tail_hyp_idx = hyp_idx
            break
    # Step 3. Fallback: walk backwards to the nearest earlier non-punctuation
    # reference token that has an aligned hypothesis token.
    fallback_idx = tail_ref_idx
    while tail_hyp_idx is None and fallback_idx >= 0:
        if ref_tokens[fallback_idx] not in punct_set:
            for ref_idx, hyp_idx in reversed(alignment):
                if ref_idx == fallback_idx and hyp_idx is not None:
                    tail_hyp_idx = hyp_idx
                    break
        fallback_idx -= 1
    if tail_hyp_idx is None or tail_hyp_idx >= len(hyp_to_asr_word_idx):
        if enable_debug:
            print(f"tail_hyp_idx 无法找到或超出范围: {tail_hyp_idx}")
        return None
    asr_word_idx = hyp_to_asr_word_idx[tail_hyp_idx]
    if asr_word_idx is None or asr_word_idx < 0 or asr_word_idx >= len(asr_words):
        if enable_debug:
            print(f"asr_word_idx 无效: {asr_word_idx}")
        return None
    return asr_words[asr_word_idx]
def find_head_word(
    ref_tokens: List[str],
    hyp_tokens: List[str],
    alignment: List[Tuple[Optional[int], Optional[int]]],
    hyp_to_asr_word_idx: dict,
    asr_words: List[WordModel],
    punct_set: frozenset = frozenset(",。!?、,.!?;")
) -> Optional[WordModel]:
    """
    Locate the ASR WordModel aligned with the first non-punctuation
    reference token.

    Note: the default punct_set is a frozenset (fix for the mutable-default-
    argument antipattern); membership semantics are unchanged.

    :return: the matching WordModel, or None when nothing maps
    """
    # 1. First non-punctuation index in ref_tokens.
    head_ref_idx = 0
    while head_ref_idx < len(ref_tokens) and ref_tokens[head_ref_idx] in punct_set:
        head_ref_idx += 1
    if head_ref_idx >= len(ref_tokens):
        return None
    # 2. Its hypothesis counterpart in the alignment (first occurrence).
    head_hyp_idx = None
    for ref_idx, hyp_idx in alignment:
        if ref_idx == head_ref_idx and hyp_idx is not None:
            head_hyp_idx = hyp_idx
            break
    if head_hyp_idx is None or head_hyp_idx >= len(hyp_to_asr_word_idx):
        return None
    # 3. Map the hypothesis token to its ASR word.
    asr_word_idx = hyp_to_asr_word_idx[head_hyp_idx]
    if asr_word_idx is None or asr_word_idx < 0 or asr_word_idx >= len(asr_words):
        return None
    return asr_words[asr_word_idx]
def calculate_sentence_delay(
    datas: List[Tuple[AudioItem, List[SegmentModel]]], language: str = "zh"
) -> (float, float, float):
    """
    Compute head/tail alignment and delay statistics between reference
    transcripts and streaming ASR results.

    :param datas: pairs of (reference AudioItem, list of ASR segments)
    :param language: language code; selects char- vs word-level alignment
    :return: (head_not_found_ratio, average_head_offset, tail_not_found_ratio,
              average_standard_offset, average_final_offset,
              average_tail_offset) — all offsets in seconds.
              NOTE(review): six values are returned although the annotation
              declares three — callers must unpack six.
    NOTE(review): the final averages divide by head_found / tail_found and
    by (found + not_found); they raise ZeroDivisionError when no item ever
    matched — confirm callers guarantee non-trivial input.
    """
    tail_offset_time = 0  # accumulated |ref end - tail word end| (ms)
    standard_offset_time = 0  # accumulated sentence-finalization delay (ms)
    tail_not_found = 0  # items whose tail word could not be located
    tail_found = 0  # items whose tail word was located
    standard_fix = 0  # negative standard offsets clamped to 0
    final_fix = 0  # negative per-segment finalization offsets clamped to 0
    head_offset_time = 0  # accumulated |head word start - ref start| (ms)
    final_offset_time = 0  # accumulated per-item mean finalization offset (ms)
    head_not_found = 0  # items whose head word could not be located
    head_found = 0  # items whose head word was located
    for audio_item, asr_list in datas:
        if not audio_item.voice:
            continue
        # Concatenate every reference segment; the last segment's end and the
        # first segment's start bound the reference utterance.
        ref_text = ""
        for voice in audio_item.voice:
            ref_text = ref_text + voice.answer.strip()
        if not ref_text:
            continue
        logger.debug(f"-=-=-=-=-=-=-=-=-=-=-=-=-=start-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=--=")
        ref_end_ms = int(audio_item.voice[-1].end * 1000)
        ref_start_ms = int(audio_item.voice[0].start * 1000)
        # Finalization offset per segment = time the segment was received
        # minus the segment's audio end time; averaged per item below.
        pred_text = []
        asr_words: List[WordModel] = []
        temp_final_offset_time = 0
        for asr in asr_list:
            # Collect word texts (one entry per word, parallel to asr_words).
            pred_text = pred_text + [word.text for word in asr.words]
            final_offset = asr.receive_time - asr.end_time
            logger.debug(f"asr.receive_time {asr.receive_time} , asr.end_time {asr.end_time} , final_offset {final_offset}")
            asr_words = asr_words + asr.words
            for word in asr.words:
                # Back-link each word to its segment for receive_time lookups.
                word.segment = asr
            if final_offset < 0:
                final_fix = final_fix + 1
                # Clamp negative offsets to 0 and count how many were clamped.
                final_offset = 0
            temp_final_offset_time = temp_final_offset_time + final_offset
        # Per-item mean finalization offset added to the global accumulator.
        final_offset_time = final_offset_time + temp_final_offset_time / len(asr_list)
        logger.debug(f"text: {ref_text},pred_text: {pred_text}")
        # Build the flat hypothesis token list plus a map from each flat
        # token index back to the word (asr_words index) it came from.
        flat_pred_tokens = []
        hyp_to_asr_word_idx = {}  # flat token index -> asr_words index
        alignment_type = get_alignment_type(language)
        if alignment_type == "chart":
            label_tokens = Tokenizer.tokenize([ref_text], language)[0]
            pred_tokens = Tokenizer.tokenize(pred_text, language)
            for asr_idx, token_group in enumerate(pred_tokens):
                for token in token_group:
                    flat_pred_tokens.append(token)
                    hyp_to_asr_word_idx[len(flat_pred_tokens) - 1] = asr_idx
            # Char-level alignment against the joined tokens.
            # NOTE(review): assumes chart tokens are single characters so that
            # char indices from align_texts equal flat token indices — confirm.
            alignment = align_texts(label_tokens, "".join(flat_pred_tokens))
        else:
            label_tokens = Tokenizer.norm_and_tokenize([ref_text], language)[0]
            pred_tokens = Tokenizer.norm_and_tokenize(pred_text, language)
            for asr_idx, token_group in enumerate(pred_tokens):
                for token in token_group:
                    flat_pred_tokens.append(token)
                    hyp_to_asr_word_idx[len(flat_pred_tokens) - 1] = asr_idx
            alignment = align_tokens(label_tokens, flat_pred_tokens)
        logger.debug(f"ref_tokens: {label_tokens}")
        logger.debug(f"pred_tokens: {pred_tokens}")
        logger.debug(f"alignment sample: {alignment[:30]}")  # cap log size
        logger.debug(f"hyp_to_asr_word_idx: {hyp_to_asr_word_idx}")
        head_word_info = find_head_word(label_tokens, pred_tokens, alignment, hyp_to_asr_word_idx, asr_words)
        if head_word_info is None:
            # Count items where the head word could not be located.
            head_not_found = head_not_found + 1
            logger.debug(f"未找到首字")
        else:
            logger.debug(f"head_word: {head_word_info.text} ref_start_ms:{ref_start_ms}")
            # Head offset = |head word start time - reference start time|.
            head_offset_time = head_offset_time + abs(head_word_info.start_time - ref_start_ms)
            head_found += 1
        # Find the word in the ASR output that carries the reference tail token.
        tail_word_info = find_tail_word(label_tokens, pred_tokens, alignment, hyp_to_asr_word_idx, asr_words)
        if tail_word_info is None:
            tail_not_found = tail_not_found + 1
            logger.debug(f"未找到尾字")
        else:
            logger.debug(f"tail_word: {tail_word_info.text} ref_end_ms: {ref_end_ms}")
            # Tail offset = |reference end time - tail word end time|.
            tail_offset_time = abs(ref_end_ms - tail_word_info.end_time) + tail_offset_time
            # Sentence delay = receive time of the tail word's segment minus
            # the reference sentence end time.
            standard_offset = tail_word_info.segment.receive_time - ref_end_ms
            logger.debug(f"tail_word_info.segment.receive_time {tail_word_info.segment.receive_time } , tail_word_info.end_time {tail_word_info.end_time} , ref_end_ms {ref_end_ms}")
            # Clamp negative delays to 0 and count how many were clamped.
            if standard_offset < 0:
                standard_offset = 0
                standard_fix = standard_fix + 1
            standard_offset_time = standard_offset + standard_offset_time
            tail_found += 1
        logger.info(f"-=-=-=-=-=-=-=-=-=-=-=-=-=end-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=--=")
    logger.debug(
        f"找到首字字数量: {head_found},未找到首字数量:{head_not_found},找到尾字数量: {tail_found},未找到尾字数量:{tail_not_found},修正标答偏移负数数量:{standard_fix},修正定稿偏移负数数量:{final_fix}")
    logger.debug(
        f"尾字偏移总时间:{tail_offset_time},标答句偏移总时间:{standard_offset_time},首字偏移总时间:{head_offset_time},定稿偏移总时间:{final_offset_time}")
    #
    # Averages; offsets converted from milliseconds to seconds (/1000).
    head_not_found_ratio = head_not_found / (head_found + head_not_found)
    tail_not_found_ratio = tail_not_found / (tail_found + tail_not_found)
    average_tail_offset = tail_offset_time / tail_found / 1000
    average_head_offset = head_offset_time / head_found / 1000
    average_standard_offset = standard_offset_time / tail_found / 1000
    average_final_offset = final_offset_time / tail_found / 1000
    logger.info(
        f"首字未找到比例:{head_not_found_ratio},尾字未找到比例:{tail_not_found_ratio},首字偏移时间:{average_head_offset},尾字偏移时间:{average_tail_offset},标答句偏移时间:{average_standard_offset},定稿偏移时间:{average_final_offset}")
    return head_not_found_ratio, average_head_offset, tail_not_found_ratio, average_standard_offset, average_final_offset, average_tail_offset
if __name__ == '__main__':
    # Ad-hoc sanity harness: compares the char-level aligner (align_texts)
    # against the token-level aligner (align_tokens), both on raw character
    # lists and on tokenizer output, printing timings and equality checks.
    checks = [
        {
            "type": "zh",
            "ref": "今天天气真好",
            "hyp": "今天真好"
        },
        {
            "type": "zh",
            "ref": "我喜欢吃苹果",
            "hyp": "我很喜欢吃香蕉"
        },
        {
            "type": "zh",
            "ref": "我喜欢吃苹果",
            "hyp": "我喜欢吃苹果"
        },
        {
            "type": "en",
            "ref": "I like to eat apples",
            "hyp": "I really like eating apples"
        },
        {
            "type": "en",
            "ref": "She is going to the market",
            "hyp": "She went market"
        },
        {
            "type": "en",
            "ref": "Hello world",
            "hyp": "Hello world"
        },
        {
            "type": "en",
            "ref": "Good morning",
            "hyp": "Bad night"
        },
    ]
    for check in checks:
        ref = check.get("ref")
        type = check.get("type")  # NOTE(review): shadows the builtin `type`
        hyp = check.get("hyp")
        res1 = align_texts(ref, hyp)
        res2 = align_tokens(list(ref), list(hyp))
        from utils.tokenizer import Tokenizer
        # Tokenizer timing probes: single vs tripled input length.
        start = time.time()
        tokens_pred = Tokenizer.norm_and_tokenize([ref], type)
        print(time.time() - start)
        start = time.time()
        Tokenizer.norm_and_tokenize([ref + ref + ref], type)
        print(time.time() - start)
        tokens_label = Tokenizer.norm_and_tokenize([hyp], type)
        print(tokens_pred)
        print(tokens_label)
        res3 = align_tokens(tokens_pred[0], tokens_label[0])
        # Do the char-level and token-level alignments agree?
        print(res1 == res2)
        print(res1 == res3)