# ASR evaluation utilities: accuracy (1 - CER), punctuation ratio, and
# sentence-delay statistics.
import re
|
||
import time
|
||
|
||
import Levenshtein
|
||
from utils.tokenizer import Tokenizer
|
||
from typing import List, Tuple
|
||
from utils.model import SegmentModel
|
||
from utils.model import AudioItem
|
||
from utils.reader import read_data
|
||
from utils.logger import logger
|
||
from utils.model import VoiceSegment
|
||
from utils.model import WordModel
|
||
from difflib import SequenceMatcher
|
||
import logging
|
||
|
||
|
||
def calculate_punctuation_ratio(datas: List[Tuple[AudioItem, List[SegmentModel]]]) -> float:
    """Compute the ratio of generated punctuation to reference punctuation.

    For every (reference, hypothesis) pair the reference segment texts and
    the generated segment texts are concatenated, punctuation is counted on
    both sides with ``count_punctuation``, and the totals are accumulated
    over the whole dataset.

    :param datas: (reference AudioItem, generated SegmentModel list) pairs
    :return: total generated punctuation / total reference punctuation,
             or 0.0 when the references contain no punctuation at all
    """
    total_standard_punctuation = 0
    total_gen_punctuation = 0
    for answer, results in datas:
        # Concatenate all reference texts / all generated texts, then count
        # punctuation once per side (join avoids quadratic string +=).
        standard_text = "".join(item.answer for item in answer.voice)
        gen_text = "".join(item.text for item in results)

        total_standard_punctuation += count_punctuation(standard_text)
        total_gen_punctuation += count_punctuation(gen_text)

    # BUG FIX: guard against ZeroDivisionError when the references carry
    # no punctuation at all (or `datas` is empty).
    if total_standard_punctuation == 0:
        return 0.0
    return total_gen_punctuation / total_standard_punctuation
|
||
|
||
|
||
def calculate_acc(datas: List[Tuple[AudioItem, List[SegmentModel]]], language: str) -> float:
    """Compute the mean per-sample accuracy (1 - CER) over a dataset.

    For each (reference, hypothesis) pair the segment texts are
    concatenated and scored with ``cal_per_cer``; the per-sample scores
    are averaged over the dataset.

    :param datas: (reference AudioItem, generated SegmentModel list) pairs
    :param language: language code, forwarded to ``cal_per_cer``
    :return: mean accuracy, or 0.0 for an empty dataset
    """
    # BUG FIX: avoid ZeroDivisionError on an empty dataset.
    if not datas:
        return 0.0

    total_acc = 0.0
    for answer, results in datas:
        # Join all reference / hypothesis texts and score them as one pair.
        standard_text = "".join(item.answer for item in answer.voice)
        gen_text = "".join(item.text for item in results)
        total_acc += cal_per_cer(gen_text, standard_text, language)

    return total_acc / len(datas)
|
||
|
||
|
||
def get_alignment_type(language: str):
    """Return the alignment granularity for *language*.

    Character-based alignment ("chart") is used for scripts without
    whitespace word boundaries; everything else aligns on words.
    """
    # zh/ja/ko/th/lo/my/km/bo: Chinese, Japanese, Korean, Thai, Lao,
    # Burmese, Khmer, Tibetan.
    char_based_languages = {"zh", "ja", "ko", "th", "lo", "my", "km", "bo"}
    return "chart" if language in char_based_languages else "word"
|
||
|
||
|
||
def cal_per_cer(text: str, answer: str, language: str):
    """Return the per-sample accuracy (1 - CER) of *text* against *answer*.

    Punctuation is stripped from both sides first.  Character-level edit
    operations (Levenshtein) are used for char-based languages, word-level
    opcodes (SequenceMatcher) otherwise.
    """
    # Empty reference: 0.0 for an empty prediction, 1.0 otherwise.
    if not answer:
        return 1.0 if text else 0.0

    text = remove_punctuation(text)
    answer = remove_punctuation(answer)

    text_chars = Tokenizer.norm_and_tokenize([text], language)[0]
    answer_chars = Tokenizer.norm_and_tokenize([answer], language)[0]

    # Reference tokenized to nothing (e.g. it was punctuation only).
    if not answer_chars:
        return 0.0

    if get_alignment_type(language) == "chart":
        # Char-based languages: realign on the raw character sequences.
        text_chars = list(text)
        answer_chars = list(answer)
        edit_ops = Levenshtein.editops(text_chars, answer_chars)
        insert = sum(1 for op in edit_ops if op[0] == "insert")
        delete = sum(1 for op in edit_ops if op[0] == "delete")
        replace = sum(1 for op in edit_ops if op[0] == "replace")
    else:
        insert = delete = replace = 0
        matcher = SequenceMatcher(None, text_chars, answer_chars)
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag == 'replace':
                replace += max(i2 - i1, j2 - j1)
            elif tag == 'delete':
                delete += i2 - i1
            elif tag == 'insert':
                insert += j2 - j1

    cer = (insert + delete + replace) / len(answer_chars)
    return 1 - cer
|
||
|
||
|
||
def cal_total_cer(samples: list):
    """Compute a corpus-level accuracy (1 - CER) over (pred, ref) pairs.

    Edit operations are accumulated over all samples and normalised once
    by the total reference length, so long references weigh more than
    short ones (unlike a per-sample average).

    :param samples: list of (pred_text, ref_text) tuples
    :return: 1 - total_CER; 1.0 when the total reference length is 0
    """
    total_insert = 0
    total_delete = 0
    total_replace = 0
    total_ref_len = 0

    for text, answer in samples:
        # BUG FIX: the original returned from inside the loop on the first
        # empty reference, aborting the whole batch.  Skip such samples
        # instead -- an empty reference contributes no length.
        if not answer:
            continue

        text = remove_punctuation(text)
        answer = remove_punctuation(answer)

        # Character-level edit operations transforming text into answer.
        ops = Levenshtein.editops(list(text), list(answer))
        total_insert += sum(1 for op in ops if op[0] == "insert")
        total_delete += sum(1 for op in ops if op[0] == "delete")
        total_replace += sum(1 for op in ops if op[0] == "replace")
        total_ref_len += len(answer)

    total_cer = (total_insert + total_delete + total_replace) / total_ref_len if total_ref_len > 0 else 0.0
    return 1 - total_cer
|
||
|
||
|
||
def remove_punctuation(text: str) -> str:
    """Strip punctuation (Chinese and Western) from *text*.

    Drops every character that is not a word character, whitespace, or a
    CJK ideograph (U+4E00-U+9FFF).
    """
    pattern = r'[^\w\s\u4e00-\u9fff]'
    return re.sub(pattern, '', text)
|
||
|
||
|
||
def count_punctuation(text: str) -> int:
    """Count punctuation characters in *text*.

    A punctuation character is anything that is not a word character,
    whitespace, or a CJK ideograph (U+4E00-U+9FA5).
    """
    matches = re.findall(r"[^\w\s\u4e00-\u9fa5]", text)
    return len(matches)
|
||
|
||
|
||
from typing import List, Optional, Tuple
|
||
|
||
|
||
def calculate_standard_sentence_delay(datas: List[Tuple[AudioItem, List[SegmentModel]]]) -> float:
    """Estimate the delay between a reference sentence end and the ASR tail word.

    NOTE(review): this looks like unfinished debug scaffolding -- ``delay``
    is computed but never aggregated or returned, ``audio_texts`` /
    ``asr_texts`` are never used, and the function always returns 0.  The
    maintained implementation appears to be ``calculate_sentence_delay``.
    """
    for audio_item, asr_results in datas:
        if not audio_item.voice:
            continue  # no reference content for this item
        #
        audio_texts = []  # unused -- presumably leftover debug state
        asr_texts = []  # unused

        ref = audio_item.voice[0]  # takes the first reference segment by default
        ref_end_ms = int(ref.end * 1000)  # reference end, seconds -> milliseconds

        # Find all ASR segments whose text contains the reference tail
        # character (simplified to: contains the last char of the answer).
        target_char = ref.answer.strip()[-1]  # tail char; raises IndexError on an all-whitespace answer -- TODO confirm inputs
        matching_results = [r for r in asr_results if target_char in r.text and r.words]

        if not matching_results:
            continue  # no ASR segment contains the tail character

        # Latest word end_time across the matching segments is taken as
        # the tail-character time.
        latest_word_time = max(word.end_time for r in matching_results for word in r.words)

        delay = latest_word_time - ref_end_ms  # computed but never used -- TODO confirm intent

        print(audio_item)
        print(asr_results)
    return 0
|
||
|
||
|
||
def align_texts(ref_text: str, hyp_text: str) -> List[Tuple[Optional[int], Optional[int]]]:
    """Character-level alignment of two strings via Levenshtein edit ops.

    Returns a list of (ref_idx, hyp_idx) pairs; an index is None where
    the corresponding side has no matching character (insert/delete).
    """
    ops = Levenshtein.editops(ref_text, hyp_text)
    n_ref = len(ref_text)
    n_hyp = len(hyp_text)
    pairs = []
    r = h = 0

    for op_name, i, j in ops:
        # Emit 1:1 matches for the untouched stretch before this edit.
        while r < i and h < j:
            pairs.append((r, h))
            r += 1
            h += 1
        if op_name == "replace":
            pairs.append((i, j))
            r, h = i + 1, j + 1
        elif op_name == "delete":
            pairs.append((i, None))
            r = i + 1
        elif op_name == "insert":
            pairs.append((None, j))
            h = j + 1

    # Matched region after the last edit operation.
    while r < n_ref and h < n_hyp:
        pairs.append((r, h))
        r += 1
        h += 1
    # Leftover unmatched tails on either side.
    while r < n_ref:
        pairs.append((r, None))
        r += 1
    while h < n_hyp:
        pairs.append((None, h))
        h += 1

    return pairs
|
||
|
||
|
||
def align_tokens(ref_text: List[str], hyp_text: List[str]) -> List[Tuple[Optional[int], Optional[int]]]:
    """Token-level alignment of two token sequences via SequenceMatcher.

    Returns (ref_idx, hyp_idx) pairs; None marks a token with no
    counterpart on the other side.
    """
    pairs: List[Tuple[Optional[int], Optional[int]]] = []
    for tag, i1, i2, j1, j2 in SequenceMatcher(None, ref_text, hyp_text).get_opcodes():
        if tag in ('equal', 'replace'):
            # zip stops at the shorter span, pairing tokens 1:1.
            pairs.extend(zip(range(i1, i2), range(j1, j2)))
        elif tag == 'delete':
            pairs.extend((r, None) for r in range(i1, i2))
        elif tag == 'insert':
            pairs.extend((None, h) for h in range(j1, j2))
    return pairs
|
||
|
||
|
||
def find_tail_word_time(
    ref_text: List[str],
    pred_text: List[str],
    merged_words: List[WordModel],
    char2word_idx: List[int],
    alignment: List[Tuple[Optional[int], Optional[int]]],
) -> Optional[WordModel]:
    """Locate the WordModel of the reference tail (last non-punct) character.

    The last non-punctuation character of ``ref_text`` is mapped through
    ``alignment`` to a character index in ``pred_text``, then through
    ``char2word_idx`` to a word index in ``merged_words``.

    :param ref_text: reference text as a character list
    :param pred_text: hypothesis text as a character list
    :param merged_words: WordModel list backing the hypothesis text
    :param char2word_idx: hypothesis char index -> merged_words index
    :param alignment: (ref_idx, hyp_idx) character alignment
    :return: the tail WordModel, or None when it cannot be resolved
    """
    punctuation = set(",。!?、,.!?;;")

    # Walk backwards over trailing punctuation in the reference.
    tail = len(ref_text) - 1
    while tail >= 0 and ref_text[tail] in punctuation:
        tail -= 1
    if tail < 0:
        return None  # reference is punctuation only

    # Latest alignment entry mapping the tail char to a hypothesis char.
    hyp_pos = next(
        (h for r, h in reversed(alignment) if r == tail and h is not None),
        None,
    )
    if hyp_pos is None or hyp_pos >= len(char2word_idx):
        return None  # no aligned hypothesis char, or stale mapping

    word_pos = char2word_idx[hyp_pos]
    if word_pos >= len(merged_words):
        return None
    return merged_words[word_pos]
|
||
|
||
|
||
def find_head_word_time(
    ref_text: List[str],
    pred_text: List[str],
    merged_words: List[WordModel],
    char2word_idx: List[int],
    alignment: List[Tuple[Optional[int], Optional[int]]],
) -> Optional[WordModel]:
    """Locate the WordModel of the reference head (first) character.

    The first character of ``ref_text`` is mapped through ``alignment``
    to a hypothesis character index, then through ``char2word_idx`` to a
    word index in ``merged_words``.

    :param ref_text: reference text as a character list
    :param pred_text: hypothesis text as a character list
    :param merged_words: WordModel list backing the hypothesis text
    :param char2word_idx: hypothesis char index -> merged_words index
    :param alignment: (ref_idx, hyp_idx) character alignment
    :return: the head WordModel, or None when it cannot be resolved
    """
    ref_head_index = 0  # the head is always the first reference character

    for ref_idx, hyp_idx in alignment:
        if ref_idx == ref_head_index and hyp_idx is not None:
            if 0 <= hyp_idx < len(char2word_idx):
                word_idx = char2word_idx[hyp_idx]
                # BUG FIX: guard the word index as find_tail_word_time
                # does, instead of risking an IndexError on a stale
                # char2word_idx mapping.
                if word_idx >= len(merged_words):
                    return None
                return merged_words[word_idx]
    return None
|
||
|
||
|
||
def merge_asr_results(asr_list: List[SegmentModel]) -> Tuple[str, List[WordModel], List[int]]:
    """Merge ASR segments into one text plus a flat word list.

    Punctuation is dropped from the merged text; ``merged_words`` keeps
    every word (punctuation included), and each word gets a back-reference
    to its segment via ``word.segment``.

    :param asr_list: ASR segments in order
    :return: (merged_text, merged_words, char2word_idx) where
             char2word_idx[i] is the index in merged_words of the word
             containing character i of merged_text
    """
    punct_set = set(",。!?、,.!?;;")  # punctuation excluded from merged_text

    merged_text = ""
    merged_words: List[WordModel] = []
    char2word_idx: List[int] = []

    for asr in asr_list:
        if not asr.text or not asr.words:
            continue  # skip empty segments
        # Single pass per segment (the original iterated asr.words twice).
        for word in asr.words:
            word.segment = asr  # back-reference used for timing bookkeeping
            merged_words.append(word)
            word_idx = len(merged_words) - 1
            for ch in word.text:
                if ch not in punct_set:
                    merged_text += ch
                    char2word_idx.append(word_idx)

    return merged_text, merged_words, char2word_idx
|
||
|
||
|
||
def rebuild_char2word_idx(pred_tokens: List[str], merged_words: List[WordModel]) -> List[int]:
    """Rebuild char2word_idx so it lines up 1:1 with ``pred_tokens``.

    Walks the characters of every word in order, recording the word index
    for each, and stops once ``len(pred_tokens)`` entries exist.
    """
    limit = len(pred_tokens)
    mapping: List[int] = []
    for word_idx, word in enumerate(merged_words):
        for _ in word.text:
            if len(mapping) >= limit:
                return mapping
            mapping.append(word_idx)
    return mapping
|
||
|
||
|
||
def build_hyp_token_to_asr_chart_index(
    hyp_tokens: List[str],
    asr_words: List[WordModel]
) -> List[int]:
    """Map each hyp token index to the ASR word index that covers it.

    Assumes each ASR word's text is the concatenation of a consecutive
    run of hyp tokens (greedy matching); unmatched tokens keep -1.
    """
    mapping = [-1] * len(hyp_tokens)

    word_pos = 0
    token_pos = 0

    while word_pos < len(asr_words) and token_pos < len(hyp_tokens):
        word_text = asr_words[word_pos].text
        span = len(word_text)
        # Candidate: the next `span` hyp tokens joined together.
        candidate = "".join(hyp_tokens[token_pos:token_pos + span])
        if candidate == word_text:
            # Match: every covered token points at this word.
            for k in range(token_pos, token_pos + span):
                mapping[k] = word_pos
            token_pos += span
            word_pos += 1
        else:
            # Tolerate a mismatch by skipping one hyp token; a smarter
            # variable-length match could be added here if needed.
            token_pos += 1

    return mapping
|
||
|
||
import re
|
||
|
||
def normalize(text: str) -> str:
    """Lowercase *text* and strip everything except word chars and apostrophes."""
    lowered = text.lower()
    return re.sub(r"[^\w']+", '', lowered)
|
||
|
||
def build_hyp_token_to_asr_word_index(hyp_tokens: List[str], asr_words: List[WordModel]) -> List[int]:
    """Greedily map hyp token indices to ASR word indices.

    Tokens and words are compared after ``normalize``; a token matches a
    word when either normalised form contains the other, which tolerates
    small tokenisation differences.  Unmatched tokens keep -1.
    """
    mapping = [-1] * len(hyp_tokens)
    t = w = 0

    while t < len(hyp_tokens) and w < len(asr_words):
        token_norm = normalize(hyp_tokens[t])
        word_norm = normalize(asr_words[w].text)

        # Substring containment in either direction improves robustness.
        # NOTE(review): an empty normalised token matches any word here.
        if token_norm == word_norm or token_norm in word_norm or word_norm in token_norm:
            mapping[t] = w
            t += 1
            w += 1
        else:
            t += 1

    return mapping
|
||
|
||
def find_tail_word(
    ref_tokens: List[str],  # reference token list
    hyp_tokens: List[str],  # hypothesis token list
    alignment: List[Tuple[Optional[int], Optional[int]]],  # (ref_idx, hyp_idx) pairs
    hyp_to_asr_word_idx: dict,
    asr_words: List[WordModel],
    punct_set: set = frozenset(",。!?、,.!?;;")
) -> Optional[WordModel]:
    """Find the ASR WordModel aligned to the reference tail token.

    The last non-punctuation reference token is mapped through
    ``alignment`` to a hypothesis token, then through
    ``hyp_to_asr_word_idx`` to an ASR word.  If the tail token has no
    aligned hypothesis token, the last aligned pair is used as fallback.

    Note: the default ``punct_set`` is a frozenset -- the original used a
    mutable ``set`` default shared across calls.

    :return: the tail WordModel, or None when it cannot be resolved
    """
    # 1. Skip trailing punctuation to find the reference tail index.
    tail_ref_idx = len(ref_tokens) - 1
    while tail_ref_idx >= 0 and ref_tokens[tail_ref_idx] in punct_set:
        tail_ref_idx -= 1
    if tail_ref_idx < 0:
        return None  # reference is punctuation only

    # 2. Find the hyp index aligned with the tail token (scan from the end).
    tail_hyp_idx = None
    for ref_idx, hyp_idx in reversed(alignment):
        if ref_idx == tail_ref_idx and hyp_idx is not None:
            tail_hyp_idx = hyp_idx
            break

    # 3. Fallback: the last alignment pair that has any hyp index.
    if tail_hyp_idx is None:
        for ref_idx, hyp_idx in reversed(alignment):
            if hyp_idx is not None:
                tail_hyp_idx = hyp_idx
                break

    if tail_hyp_idx is None or tail_hyp_idx >= len(hyp_to_asr_word_idx):
        return None

    # 4. Map to the ASR word index, guarding against invalid entries.
    asr_word_idx = hyp_to_asr_word_idx[tail_hyp_idx]
    if asr_word_idx is None or asr_word_idx < 0 or asr_word_idx >= len(asr_words):
        return None

    return asr_words[asr_word_idx]
|
||
|
||
|
||
def find_tail_word2(
    ref_tokens: List[str],  # reference token list
    hyp_tokens: List[str],  # hypothesis token list
    alignment: List[Tuple[Optional[int], Optional[int]]],  # (ref_idx, hyp_idx) pairs
    hyp_to_asr_word_idx: List[int],  # hyp token index -> ASR word index
    asr_words: List[WordModel],
    punct_set: set = frozenset(",。!?、,.!?;;"),
    enable_debug: bool = False
) -> Optional[WordModel]:
    """Find the ASR WordModel for the reference tail token, with fallback.

    Like ``find_tail_word``, but the fallback walks backwards over earlier
    non-punctuation reference tokens until one aligns, and diagnostics can
    be printed with ``enable_debug``.

    Note: the default ``punct_set`` is a frozenset -- the original used a
    mutable ``set`` default shared across calls.

    :return: the tail WordModel, or None when it cannot be resolved
    """
    # Step 1. Last non-punctuation reference token.
    tail_ref_idx = len(ref_tokens) - 1
    while tail_ref_idx >= 0 and ref_tokens[tail_ref_idx] in punct_set:
        tail_ref_idx -= 1
    if tail_ref_idx < 0:
        if enable_debug:
            print("全是标点,尾字找不到")
        return None

    # Step 2. Hyp index aligned with the tail token (scan from the end).
    tail_hyp_idx = None
    for ref_idx, hyp_idx in reversed(alignment):
        if ref_idx == tail_ref_idx and hyp_idx is not None:
            tail_hyp_idx = hyp_idx
            break

    # Step 3. Fallback: walk backwards to the nearest non-punctuation
    # reference token that has an aligned hyp token.
    fallback_idx = tail_ref_idx
    while tail_hyp_idx is None and fallback_idx >= 0:
        if ref_tokens[fallback_idx] not in punct_set:
            for ref_idx, hyp_idx in reversed(alignment):
                if ref_idx == fallback_idx and hyp_idx is not None:
                    tail_hyp_idx = hyp_idx
                    break
        fallback_idx -= 1

    if tail_hyp_idx is None or tail_hyp_idx >= len(hyp_to_asr_word_idx):
        if enable_debug:
            print(f"tail_hyp_idx 无法找到或超出范围: {tail_hyp_idx}")
        return None

    # Step 4. Map to the ASR word index, guarding against invalid entries.
    asr_word_idx = hyp_to_asr_word_idx[tail_hyp_idx]
    if asr_word_idx is None or asr_word_idx < 0 or asr_word_idx >= len(asr_words):
        if enable_debug:
            print(f"asr_word_idx 无效: {asr_word_idx}")
        return None

    return asr_words[asr_word_idx]
|
||
|
||
|
||
def find_head_word(
    ref_tokens: List[str],
    hyp_tokens: List[str],
    alignment: List[Tuple[Optional[int], Optional[int]]],
    hyp_to_asr_word_idx: dict,
    asr_words: List[WordModel],
    punct_set: set = frozenset(",。!?、,.!?;;")
) -> Optional[WordModel]:
    """Find the ASR WordModel aligned to the first non-punctuation
    reference token.

    Note: the default ``punct_set`` is a frozenset -- the original used a
    mutable ``set`` default shared across calls.

    :return: the head WordModel, or None when it cannot be resolved
    """
    # 1. First non-punctuation reference token.
    head_ref_idx = 0
    while head_ref_idx < len(ref_tokens) and ref_tokens[head_ref_idx] in punct_set:
        head_ref_idx += 1
    if head_ref_idx >= len(ref_tokens):
        return None  # reference is punctuation only

    # 2. Hyp index aligned with the head token (scan from the front).
    head_hyp_idx = None
    for ref_idx, hyp_idx in alignment:
        if ref_idx == head_ref_idx and hyp_idx is not None:
            head_hyp_idx = hyp_idx
            break

    if head_hyp_idx is None or head_hyp_idx >= len(hyp_to_asr_word_idx):
        return None

    # 3. Map to the ASR word index, guarding against invalid entries.
    asr_word_idx = hyp_to_asr_word_idx[head_hyp_idx]
    if asr_word_idx is None or asr_word_idx < 0 or asr_word_idx >= len(asr_words):
        return None

    return asr_words[asr_word_idx]
|
||
|
||
|
||
def calculate_sentence_delay(
        datas: List[Tuple[AudioItem, List[SegmentModel]]], language: str = "zh"
) -> Tuple[float, float, float, float, float, float]:
    """Compute head/tail timing-delay statistics between references and ASR output.

    For every (reference, ASR segments) pair the reference text and the
    flattened ASR word texts are tokenized, aligned, and the first/last
    reference tokens are mapped back to ASR words to measure timing
    offsets.  All internal times are milliseconds; averages are divided
    by 1000 to yield seconds.

    :param datas: (reference AudioItem, ASR SegmentModel list) pairs
    :param language: language code; selects char- vs word-level alignment
    :return: (head_not_found_ratio, average_head_offset,
              tail_not_found_ratio, average_standard_offset,
              average_final_offset, average_tail_offset)

    NOTE(review): the divisions at the end raise ZeroDivisionError when no
    head/tail was ever found or ``datas`` is effectively empty -- confirm
    callers guarantee non-empty, resolvable input.
    """
    tail_offset_time = 0  # accumulated |ref end - tail word end| (ms)
    standard_offset_time = 0  # accumulated sentence-finalisation delay (ms)
    tail_not_found = 0  # samples where no tail word was resolved
    tail_found = 0  # samples where a tail word was resolved

    standard_fix = 0  # negative standard offsets clamped to 0
    final_fix = 0  # negative finalisation offsets clamped to 0

    head_offset_time = 0  # accumulated |head word start - ref start| (ms)
    final_offset_time = 0  # accumulated per-item mean finalisation offset (ms)
    head_not_found = 0  # samples where no head word was resolved
    head_found = 0  # samples where a head word was resolved

    for audio_item, asr_list in datas:
        if not audio_item.voice:
            continue
        # Concatenate all reference texts (defensive: there may be several
        # segments); sentence boundaries come from the first segment's
        # start and the last segment's end.
        ref_text = ""
        for voice in audio_item.voice:
            ref_text = ref_text + voice.answer.strip()
        if not ref_text:
            continue
        logger.debug(f"-=-=-=-=-=-=-=-=-=-=-=-=-=start-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=--=")

        ref_end_ms = int(audio_item.voice[-1].end * 1000)  # seconds -> ms
        ref_start_ms = int(audio_item.voice[0].start * 1000)

        # Flatten the ASR words, tag each word with its segment, and
        # accumulate the per-item mean finalisation offset
        # (segment receive_time - segment end_time, clamped at 0).
        pred_text = []
        asr_words: List[WordModel] = []
        temp_final_offset_time = 0
        for asr in asr_list:
            pred_text = pred_text + [word.text for word in asr.words]
            final_offset = asr.receive_time - asr.end_time
            logger.debug(f"asr.receive_time {asr.receive_time} , asr.end_time {asr.end_time} , final_offset {final_offset}")
            asr_words = asr_words + asr.words
            for word in asr.words:
                word.segment = asr  # back-reference used for receive_time below
            if final_offset < 0:
                final_fix = final_fix + 1  # count clamped negatives
                final_offset = 0
            temp_final_offset_time = temp_final_offset_time + final_offset
        # NOTE(review): raises ZeroDivisionError when asr_list is empty.
        final_offset_time = final_offset_time + temp_final_offset_time / len(asr_list)

        logger.debug(f"text: {ref_text},pred_text: {pred_text}")

        # pred_text mirrors asr_words 1:1; flatten the per-word token
        # groups and remember, for each flattened token, the index of the
        # ASR word it came from.
        flat_pred_tokens = []
        hyp_to_asr_word_idx = {}  # flat token index -> asr_words index

        alignment_type = get_alignment_type(language)
        if alignment_type == "chart":
            # Char-based languages: align raw character sequences.
            label_tokens = Tokenizer.tokenize([ref_text], language)[0]
            pred_tokens = Tokenizer.tokenize(pred_text, language)
            for asr_idx, token_group in enumerate(pred_tokens):
                for token in token_group:
                    flat_pred_tokens.append(token)
                    hyp_to_asr_word_idx[len(flat_pred_tokens) - 1] = asr_idx
            alignment = align_texts(label_tokens, "".join(flat_pred_tokens))
        else:
            # Word-based languages: align normalized token sequences.
            label_tokens = Tokenizer.norm_and_tokenize([ref_text], language)[0]
            pred_tokens = Tokenizer.norm_and_tokenize(pred_text, language)
            for asr_idx, token_group in enumerate(pred_tokens):
                for token in token_group:
                    flat_pred_tokens.append(token)
                    hyp_to_asr_word_idx[len(flat_pred_tokens) - 1] = asr_idx
            alignment = align_tokens(label_tokens, flat_pred_tokens)

        logger.debug(f"ref_tokens: {label_tokens}")
        logger.debug(f"pred_tokens: {pred_tokens}")
        logger.debug(f"alignment sample: {alignment[:30]}")  # cap log size
        logger.debug(f"hyp_to_asr_word_idx: {hyp_to_asr_word_idx}")

        head_word_info = find_head_word(label_tokens, pred_tokens, alignment, hyp_to_asr_word_idx, asr_words)

        if head_word_info is None:
            # Count samples whose head token could not be resolved.
            head_not_found = head_not_found + 1
            logger.debug(f"未找到首字")
        else:
            logger.debug(f"head_word: {head_word_info.text} ref_start_ms:{ref_start_ms}")
            # Head offset: |ASR head word start_time - reference start|.
            head_offset_time = head_offset_time + abs(head_word_info.start_time - ref_start_ms)
            head_found += 1

        # Resolve the ASR word carrying the reference tail token.
        tail_word_info = find_tail_word(label_tokens, pred_tokens, alignment, hyp_to_asr_word_idx, asr_words)
        if tail_word_info is None:
            # Count samples whose tail token could not be resolved.
            tail_not_found = tail_not_found + 1
            logger.debug(f"未找到尾字")
        else:
            logger.debug(f"tail_word: {tail_word_info.text} ref_end_ms: {ref_end_ms}")
            # Tail offset: |reference end - ASR tail word end_time|.
            tail_offset_time = abs(ref_end_ms - tail_word_info.end_time) + tail_offset_time

            # Sentence delay: receive time of the tail word's segment minus
            # the reference end time, clamped at 0.
            standard_offset = tail_word_info.segment.receive_time - ref_end_ms
            logger.debug(f"tail_word_info.segment.receive_time {tail_word_info.segment.receive_time } , tail_word_info.end_time {tail_word_info.end_time} , ref_end_ms {ref_end_ms}")
            if standard_offset < 0:
                standard_offset = 0
                standard_fix = standard_fix + 1  # count clamped negatives
            standard_offset_time = standard_offset + standard_offset_time

            tail_found += 1

        logger.info(f"-=-=-=-=-=-=-=-=-=-=-=-=-=end-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=--=")

    logger.debug(
        f"找到首字字数量: {head_found},未找到首字数量:{head_not_found},找到尾字数量: {tail_found},未找到尾字数量:{tail_not_found},修正标答偏移负数数量:{standard_fix},修正定稿偏移负数数量:{final_fix},")
    logger.debug(
        f"尾字偏移总时间:{tail_offset_time},标答句偏移总时间:{standard_offset_time},首字偏移总时间:{head_offset_time},定稿偏移总时间:{final_offset_time},")

    # Ratios and averages; offsets converted from ms to seconds.
    head_not_found_ratio = head_not_found / (head_found + head_not_found)
    tail_not_found_ratio = tail_not_found / (tail_found + tail_not_found)

    average_tail_offset = tail_offset_time / tail_found / 1000
    average_head_offset = head_offset_time / head_found / 1000

    # NOTE(review): both are divided by tail_found -- presumably intended
    # for standard_offset_time, but final_offset_time is a per-item sum;
    # confirm it should not be divided by the number of items instead.
    average_standard_offset = standard_offset_time / tail_found / 1000
    average_final_offset = final_offset_time / tail_found / 1000

    logger.info(
        f"首字未找到比例:{head_not_found_ratio},尾字未找到比例:{tail_not_found_ratio},首字偏移时间:{average_head_offset},尾字偏移时间:{average_tail_offset},标答句偏移时间:{average_standard_offset},定稿偏移时间:{average_final_offset}")

    return head_not_found_ratio, average_head_offset, tail_not_found_ratio, average_standard_offset, average_final_offset, average_tail_offset
|
||
|
||
|
||
if __name__ == '__main__':
    # Ad-hoc smoke test comparing the three alignment strategies on a few
    # Chinese and English reference/hypothesis pairs.
    checks = [
        {"type": "zh", "ref": "今天天气真好", "hyp": "今天真好"},
        {"type": "zh", "ref": "我喜欢吃苹果", "hyp": "我很喜欢吃香蕉"},
        {"type": "zh", "ref": "我喜欢吃苹果", "hyp": "我喜欢吃苹果"},
        {"type": "en", "ref": "I like to eat apples", "hyp": "I really like eating apples"},
        {"type": "en", "ref": "She is going to the market", "hyp": "She went market"},
        {"type": "en", "ref": "Hello world", "hyp": "Hello world"},
        {"type": "en", "ref": "Good morning", "hyp": "Bad night"},
    ]
    for check in checks:
        ref = check.get("ref")
        # Renamed from `type` to avoid shadowing the builtin.
        lang = check.get("type")
        hyp = check.get("hyp")

        # Character-level alignment via Levenshtein vs SequenceMatcher.
        res1 = align_texts(ref, hyp)
        res2 = align_tokens(list(ref), list(hyp))

        # Tokenizer is imported at module level; time two calls to gauge
        # tokenization cost relative to input length.
        start = time.time()
        tokens_pred = Tokenizer.norm_and_tokenize([ref], lang)
        print(time.time() - start)

        start = time.time()
        Tokenizer.norm_and_tokenize([ref + ref + ref], lang)
        print(time.time() - start)

        tokens_label = Tokenizer.norm_and_tokenize([hyp], lang)
        print(tokens_pred)
        print(tokens_label)
        res3 = align_tokens(tokens_pred[0], tokens_label[0])
        print(res1 == res2)
        print(res1 == res3)
|