# ASR streaming-recognition evaluation: accumulates CER, sentence-cut rate
# and RTF statistics per query.
# NOTE(review): removed non-Python web-viewer chrome ("Files", "446 lines",
# "Raw Permalink Normal View History", timestamp) that had been pasted above
# the module and would break import.
import os
import subprocess
from collections import defaultdict
from typing import Dict, List
from utils import asr_ter
from utils.logger import logger
log_mid_result = int(os.getenv("log", 0)) == 1
class AsrEvaluator:
def __init__(self) -> None:
self.query_count = 0 # query 数目(语音数目)
self.voice_count = 0 # 有开始和结束时间的语音条数(用于 RTF 计算)
self.cut_punc = [] # 切分标点符号,需要注意切分的时候根据列表中的顺序进行切分,比如 ... 应该放到 . 之前。
# cer 属性
self.one_minus_cer = 0 # 每个 query 的 1 - cer 和
self.token_count = 0 # 每个 query 的字数/词数和
# 句子切分率属性
self.miss_count = 0 # 每个 query miss-count 和
self.more_count = 0 # 每个 query more-count 和
self.cut_count = 0 # 每个 query cut-count 和
self.rate = 0 # 每个 query 的 cut-rate 和
# detail case
self.result = []
def evaluate(self, eval_result):
pass
def post_evaluate(self):
pass
def gen_result(self) -> Dict:
output_result = dict()
output_result["query_count"] = self.query_count
output_result["voice_count"] = self.voice_count
output_result["token_cnt"] = self.token_count
output_result["one_minus_cer"] = self.one_minus_cer
output_result["one_minus_cer_metrics"] = self.one_minus_cer / self.query_count
output_result["miss_count"] = self.miss_count
output_result["more_count"] = self.more_count
output_result["cut_count"] = self.cut_count
output_result["cut_rate"] = self.rate
output_result["cut_rate_metrics"] = self.rate / self.query_count
output_result["rtf"] = self.rtf
output_result["rtf_end"] = self.rtf_end
output_result["rtf_metrics"] = self.rtf / self.voice_count
output_result["rtf_end_metrics"] = self.rtf_end / self.voice_count
detail_case = self.result
return output_result, detail_case
def _get_predict_final_sentences(self, predict_data: List[Dict]) -> List[str]:
"""
获取 predict data 数据然后将其中 final 的句子拿出来放到列表里
"""
return [
item["recoginition_results"]["text"]
for item in predict_data
if item["recoginition_results"]["final_result"]
]
def _sentence_final_index(self, sentences: List[str], tokens: List[str], tokenizer="word") -> List[int]:
"""
获取 sentence 结束的字对应的 token 索引值
"""
token_index_list = []
token_idx = 0
for sentence in sentences:
for token in Tokenizer.tokenize(sentence, tokenizer):
if token not in tokens:
continue
while tokens[token_idx] != token:
token_idx += 1
token_index_list.append(token_idx)
return token_index_list
def _voice_to_cut_sentence(self, voice_sentences: List[str]) -> Dict:
"""
将数据集的语音片段转换为最小切分单元列表
使用 cut_punc 中的所有 punc 进行依次切分最后去除掉完全空的内容
示例
["你好,你好呀", "你好,我在写抽象的代码逻辑"]
->
cut_sentences: ["你好", "你好呀", "你好", "我在写抽象的代码逻辑"]
cut_sentence_index_list: [1, 3] ("你好呀" 对应 1-idx, "我在写抽象的代码逻辑" 对应 3-idx)
"""
voice_sentences_result = defaultdict(list)
for voice_sentence in voice_sentences:
sentence_list = [voice_sentence]
sentence_tmp_list = []
for punc in self.cut_punc:
for sentence in sentence_list:
sentence_tmp_list.extend(sentence.split(punc))
sentence_list, sentence_tmp_list = sentence_tmp_list, []
sentence_list = [item for item in sentence_list if item]
# 切分后的句子单元
voice_sentences_result["cut_sentences"].extend(sentence_list)
# 每个语音单元最后一个字对应的句子单元的索引
voice_sentences_result["cut_sentence_index_list"].append(len(voice_sentences_result["cut_sentences"]) - 1)
return voice_sentences_result
def _voice_bytes_index(self, timestamp, sample_rate=16000, bit_depth=16, channels=1):
"""
timestamp: 时间, 单位秒
"""
bytes_per_sample = bit_depth // 8
return timestamp * sample_rate * bytes_per_sample * channels
class AsrZhEvaluator(AsrEvaluator):
"""
中文的评估方式
"""
def __init__(self):
super().__init__()
self.cut_zh_punc = ["······", "......", "", "", "", "", "", ""]
self.cut_en_punc = ["...", ".", ",", "?", "!", ";", ":"]
self.cut_punc = self.cut_zh_punc + self.cut_en_punc
def evaluate(self, eval_result) -> Dict:
self.query_count += 1
self.voice_count += len(eval_result["voice"])
# 获取,标注结果 & 语音单元(非句子单元)
label_voice_sentences = [item["answer"] for item in eval_result["voice"]]
# print("label_voice_sentences", label_voice_sentences)
# 获取,标注结果 & 语音单元 -> 句子单元的转换情况
voice_to_cut_info = self._voice_to_cut_sentence(label_voice_sentences)
# print("voice_to_cut_info", voice_to_cut_info)
# 获取,标注结果 & 句子单元
label_sentences = voice_to_cut_info["cut_sentences"]
# 获取,标注结果 & 语音单元 -> 句子单元的映射关系,每个语音单元最后一个字对应的句子单元的索引
cut_sentence_index_list = voice_to_cut_info["cut_sentence_index_list"]
# 标注结果 & 句子单元 & norm 操作
label_sentences = [self._sentence_norm(sentence) for sentence in label_sentences]
if log_mid_result:
logger.info(f"label_sentences {label_sentences}")
# print("label_sentences", label_sentences)
# 预测结果 & 句子单元
predict_sentences_raw = self._get_predict_final_sentences(eval_result["predict_data"])
# print("predict_sentences_raw", predict_sentences_raw)
# 预测结果 & 句子单元 & norm 操作
predict_sentences = [self._sentence_norm(sentence) for sentence in predict_sentences_raw]
if log_mid_result:
logger.info(f"predict_sentences {predict_sentences}")
# print("predict_sentences", predict_sentences)
# 基于最小编辑距离进行 token 匹配,获得匹配后的 token 列表
label_tokens, predict_tokens = self._sentence_transfer("".join(label_sentences), "".join(predict_sentences))
# cer 计算
cer_info = self.cer(label_sentences, predict_sentences)
if log_mid_result:
logger.info(f"cer_info {cer_info}")
# print("cer_info", cer_info)
self.one_minus_cer += cer_info["one_minus_cer"]
self.token_count += cer_info["token_count"]
# 句子切分准召率
cut_info = self.cut_rate(label_sentences, predict_sentences, label_tokens, predict_tokens)
if log_mid_result:
logger.info(f"{cut_info['miss_count']}, {cut_info['more_count']}, {cut_info['rate']}")
# print("cut_info", cut_info)
# print(cut_info["miss_count"], cut_info["more_count"], cut_info["rate"])
self.miss_count += cut_info["miss_count"]
self.more_count += cut_info["more_count"]
self.cut_count += cut_info["cut_count"]
self.rate += cut_info["rate"]
self.result.append(
{
"label_tokens": label_tokens,
"predict_tokens": predict_tokens,
"one_minus_cer": cer_info["one_minus_cer"],
"token_count": cer_info["one_minus_cer"],
"miss_count": cut_info["miss_count"],
"more_count": cut_info["more_count"],
"cut_count": cut_info["cut_count"],
"rate": cut_info["rate"],
}
)
def cer(self, label_sentences, predict_sentences):
pred_str = ''.join(predict_sentences) if predict_sentences is not None else ''
label_str = ''.join(label_sentences)
r = asr_ter.calc_ter_speechio(pred_str, label_str)
one_minus_cer = max(1.0 - r['ter'], 0)
token_count = r['ref_all_token_cnt']
return {"one_minus_cer": one_minus_cer, "token_count": token_count}
def cut_rate(self, label_sentences, predict_sentences, label_tokens, predict_tokens):
label_final_index_list = set(self._sentence_final_index(label_sentences, label_tokens))
pred_final_index_list = set(self._sentence_final_index(predict_sentences, predict_tokens))
label_sentence_count = len(label_final_index_list)
miss_count = len(label_final_index_list - pred_final_index_list)
more_count = len(pred_final_index_list - label_final_index_list)
rate = max(1 - (miss_count + more_count * 2) / label_sentence_count, 0)
return {
"miss_count": miss_count,
"more_count": more_count,
"cut_count": label_sentence_count,
"rate": rate,
"label_final_index_list": label_final_index_list,
"pred_final_index_list": pred_final_index_list,
}
def _sentence_norm(self, sentence, tokenizer="word"):
"""
对句子进行 norm 操作
"""
from utils.speechio import textnorm_zh as textnorm
if tokenizer == "word":
normalizer = textnorm.TextNorm(
to_banjiao=True,
to_upper=True,
to_lower=False,
remove_fillers=True,
remove_erhua=False, # 这里同批量识别不同,改成了 False
check_chars=False,
remove_space=False,
cc_mode="",
)
return normalizer(sentence)
else:
logger.error("tokenizer error, not support.")
def _sentence_transfer(self, label_sentence: str, predict_sentence: str, tokenizer="char"):
"""
基于最小编辑距离 label predict 进行字的位置匹配并生成转换后的结果
args:
label: "今天的通话质量不错呀昨天的呢"
predict: "今天的通话质量不错昨天呢星期"
tokenizer: 分词方式
return:
label: ["", "", "", "", "", "", "", "", "", "", "", "", "", "", None, None]
predict: ["", "", "", "", "", "", "", "", "", None, "", "", None, "", "", ""]
"""
from utils.speechio import error_rate_zh as error_rate
if tokenizer == "char":
alignment, score = error_rate.EditDistance(
error_rate.tokenize_text(label_sentence, tokenizer),
error_rate.tokenize_text(predict_sentence, tokenizer),
)
label_tokens, pred_tokens = [], []
for align in alignment:
# print(align.__dict__)
label_tokens.append(align.ref)
pred_tokens.append(align.hyp)
return (label_tokens, pred_tokens)
else:
logger.error("tokenizer 出错了,暂时不支持其它的")
def _pred_data_transfer(self, predict_data, recv_time):
"""
predict_data = [
{"recoginition_results": {"text": "1", "final_result": False, "para_seq": 0}},
{"recoginition_results": {"text": "12", "final_result": False, "para_seq": 0}},
{"recoginition_results": {"text": "123", "final_result": True, "para_seq": 0}},
{"recoginition_results": {"text": "4", "final_result": False, "para_seq": 0}},
{"recoginition_results": {"text": "45", "final_result": False, "para_seq": 0}},
{"recoginition_results": {"text": "456", "final_result": True, "para_seq": 0}},
]
recv_time = [1, 3, 5, 6, 7, 8]
->
[
[{'text': '1', 'time': 1}, {'text': '12', 'time': 3}, {'text': '123', 'time': 5}],
[{'text': '4', 'time': 6}, {'text': '45', 'time': 7}, {'text': '456', 'time': 8}],
]
"""
pred_sentence_info = []
pred_sentence_index = 0
for predict_item, recv_time_item in zip(predict_data, recv_time):
if len(pred_sentence_info) == pred_sentence_index:
pred_sentence_info.append([])
pred_sentence_info[pred_sentence_index].append(
{
"text": predict_item["recoginition_results"]["text"],
"time": recv_time_item,
}
)
if predict_item["recoginition_results"]["final_result"]:
pred_sentence_index += 1
return pred_sentence_info
class AsrEnEvaluator(AsrEvaluator):
"""
英文的评估方式
"""
def evaluate(self, eval_result) -> Dict:
self.query_count += 1
self.voice_count += len(eval_result["voice"])
# 获取,标注结果 & 语音单元(非句子单元)
label_voice_sentences = [item["answer"] for item in eval_result["voice"]]
# print("label_voice_sentences", label_voice_sentences)
# 获取,标注结果 & 语音单元 -> 句子单元的转换情况
voice_to_cut_info = self._voice_to_cut_sentence(label_voice_sentences)
# print("voice_to_cut_info", voice_to_cut_info)
# 获取,标注结果 & 句子单元
label_sentences = voice_to_cut_info["cut_sentences"]
# 获取,标注结果 & 语音单元 -> 句子单元的映射关系,每个语音单元最后一个字对应的句子单元的索引
cut_sentence_index_list = voice_to_cut_info["cut_sentence_index_list"]
# 标注结果 & 句子单元 & norm 操作
label_sentences = self._sentence_list_norm(label_sentences)
# [self._sentence_norm(sentence) for sentence in label_sentences]
# print("label_sentences", label_sentences)
if log_mid_result:
logger.info(f"label_sentences {label_sentences}")
# 预测结果 & 句子单元
predict_sentences_raw = self._get_predict_final_sentences(eval_result["predict_data"])
# print("predict_sentences_raw", predict_sentences_raw)
# 预测结果 & 句子单元 & norm 操作
predict_sentences = self._sentence_list_norm(predict_sentences_raw)
# [self._sentence_norm(sentence) for sentence in predict_sentences_raw]
# print("predict_sentences", predict_sentences)
if log_mid_result:
logger.info(f"predict_sentences {predict_sentences}")
label_tokens, predict_tokens = self._sentence_transfer(" ".join(label_sentences), " ".join(predict_sentences))
# print(label_tokens)
# print(predict_tokens)
# cer 计算
cer_info = self.cer(label_tokens, predict_tokens)
# print("cer_info", cer_info)
if log_mid_result:
logger.info(f"cer_info {cer_info}")
self.one_minus_cer += cer_info["one_minus_cer"]
self.token_count += cer_info["token_count"]
# 句子切分准召率
cut_info = self.cut_rate(label_sentences, predict_sentences, label_tokens, predict_tokens)
# print(cut_info["miss_count"], cut_info["more_count"], cut_info["rate"])
# print("cut_info", cut_info)
if log_mid_result:
logger.info(f"{cut_info['miss_count']}, {cut_info['more_count']}, {cut_info['rate']}")
self.miss_count += cut_info["miss_count"]
self.more_count += cut_info["more_count"]
self.cut_count += cut_info["cut_count"]
self.rate += cut_info["rate"]
self.result.append(
{
"label_tokens": label_tokens,
"predict_tokens": predict_tokens,
"one_minus_cer": cer_info["one_minus_cer"],
"token_count": cer_info["one_minus_cer"],
"miss_count": cut_info["miss_count"],
"more_count": cut_info["more_count"],
"cut_count": cut_info["cut_count"],
"rate": cut_info["rate"],
}
)
def cer(self, label_tokens, predict_tokens):
s, d, i, c = 0, 0, 0, 0
for label_token, predict_token in zip(label_tokens, predict_tokens):
if label_token == predict_token:
c += 1
elif predict_token is None:
d += 1
elif label_token is None:
i += 1
else:
s += 1
cer = (s + d + i) / (s + d + c)
one_minus_cer = max(1.0 - cer, 0)
token_count = s + d + c
return {"one_minus_cer": one_minus_cer, "token_count": token_count}
def cut_rate(self, label_sentences, predict_sentences, label_tokens, predict_tokens):
label_final_index_list = set(self._sentence_final_index(label_sentences, label_tokens, "whitespace"))
pred_final_index_list = set(self._sentence_final_index(predict_sentences, predict_tokens, "whitespace"))
label_sentence_count = len(label_final_index_list)
miss_count = len(label_final_index_list - pred_final_index_list)
more_count = len(pred_final_index_list - label_final_index_list)
rate = max(1 - (miss_count + more_count * 2) / label_sentence_count, 0)
return {
"miss_count": miss_count,
"more_count": more_count,
"cut_count": label_sentence_count,
"rate": rate,
"label_final_index_list": label_final_index_list,
"pred_final_index_list": pred_final_index_list,
}
def _sentence_list_norm(self, sentence_list, tokenizer="whitespace"):
pwd = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
with open('./predict.txt', 'w', encoding='utf-8') as fp:
for idx, sentence in enumerate(sentence_list):
fp.write('%s\t%s\n' % (idx, sentence))
subprocess.run(
f'PYTHONPATH={pwd}/utils/speechio python {pwd}/utils/speechio/textnorm_en.py --has_key --to_upper ./predict.txt ./predict_norm.txt',
shell=True,
check=True,
)
sentence_norm = []
with open('./predict_norm.txt', 'r', encoding='utf-8') as fp:
for line in fp.readlines():
line_split_result = line.strip().split('\t', 1)
if len(line_split_result) >= 2:
sentence_norm.append(line_split_result[1])
# 有可能没有 norm 后就没了
return sentence_norm
def _sentence_transfer(self, label_sentence: str, predict_sentence: str, tokenizer="whitespace"):
"""
基于最小编辑距离 label predict 进行字的位置匹配并生成转换后的结果
args:
label: "HELLO WORLD ARE U OK YEP"
predict: "HELLO WORLD U ARE U OK YEP"
tokenizer: 分词方式
return:
label: ["HELLO", "WORLD", None, "ARE", "U", "OK", "YEP"]
predict: ["HELLO", "WORLD", "U", "ARE", "U", "OK", "YEP"]
"""
from utils.speechio import error_rate_zh as error_rate
if tokenizer == "whitespace":
alignment, score = error_rate.EditDistance(
error_rate.tokenize_text(label_sentence, tokenizer),
error_rate.tokenize_text(predict_sentence, tokenizer),
)
label_tokens, pred_tokens = [], []
for align in alignment:
label_tokens.append(align.ref)
pred_tokens.append(align.hyp)
return (label_tokens, pred_tokens)
else:
logger.error("tokenizer 出错了,暂时不支持其它的")
def post_evaluate(self) -> Dict:
pass