import os
import subprocess
from collections import defaultdict
from typing import Dict, List, Tuple
from utils import asr_ter
from utils.logger import logger
log_mid_result = int(os.getenv("log", 0)) == 1
class AsrEvaluator:
def __init__(self) -> None:
        self.query_count = 0  # number of queries (number of audio clips)
        self.voice_count = 0  # number of clips with start and end times (used for RTF computation)
        self.cut_punc = []  # splitting punctuation; split in list order, e.g. "..." must come before "."
        # CER attributes
        self.one_minus_cer = 0  # sum of per-query (1 - CER)
        self.token_count = 0  # sum of per-query character/word counts
        # sentence-split-rate attributes
        self.miss_count = 0  # sum of per-query miss-counts
        self.more_count = 0  # sum of per-query more-counts
        self.cut_count = 0  # sum of per-query cut-counts
        self.rate = 0  # sum of per-query cut-rates
        # RTF accumulators; gen_result() reads these, so initialize them here
        self.rtf = 0
        self.rtf_end = 0
        # detailed per-query results
        self.result = []
def evaluate(self, eval_result):
pass
def post_evaluate(self):
pass
    def gen_result(self) -> Tuple[Dict, List]:
output_result = dict()
output_result["query_count"] = self.query_count
output_result["voice_count"] = self.voice_count
output_result["token_cnt"] = self.token_count
output_result["one_minus_cer"] = self.one_minus_cer
output_result["one_minus_cer_metrics"] = self.one_minus_cer / self.query_count
output_result["miss_count"] = self.miss_count
output_result["more_count"] = self.more_count
output_result["cut_count"] = self.cut_count
output_result["cut_rate"] = self.rate
output_result["cut_rate_metrics"] = self.rate / self.query_count
output_result["rtf"] = self.rtf
output_result["rtf_end"] = self.rtf_end
output_result["rtf_metrics"] = self.rtf / self.voice_count
output_result["rtf_end_metrics"] = self.rtf_end / self.voice_count
detail_case = self.result
return output_result, detail_case
def _get_predict_final_sentences(self, predict_data: List[Dict]) -> List[str]:
"""
获取 predict data 数据,然后将其中 final 的句子拿出来,放到列表里。
"""
return [
item["recoginition_results"]["text"]
for item in predict_data
if item["recoginition_results"]["final_result"]
]
def _sentence_final_index(self, sentences: List[str], tokens: List[str], tokenizer="word") -> List[int]:
"""
获取 sentence 结束的字对应的 token 索引值。
"""
token_index_list = []
token_idx = 0
for sentence in sentences:
for token in Tokenizer.tokenize(sentence, tokenizer):
if token not in tokens:
continue
while tokens[token_idx] != token:
token_idx += 1
token_index_list.append(token_idx)
return token_index_list
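    # Worked example for _sentence_final_index (hypothetical, whitespace mode):
    # sentences ["HI THERE", "OK"] with aligned tokens ["HI", "THERE", None, "OK"]
    # yield [1, 3] -- the token indices holding each sentence's final word
    # (None entries produced by edit-distance alignment are walked past).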
def _voice_to_cut_sentence(self, voice_sentences: List[str]) -> Dict:
"""
将数据集的语音片段转换为最小切分单元列表。
使用 cut_punc 中的所有 punc 进行依次切分,最后去除掉完全空的内容
示例:
["你好,你好呀", "你好,我在写抽象的代码逻辑"]
->
cut_sentences: ["你好", "你好呀", "你好", "我在写抽象的代码逻辑"]
cut_sentence_index_list: [1, 3] ("你好呀" 对应 1-idx, "我在写抽象的代码逻辑" 对应 3-idx)
"""
voice_sentences_result = defaultdict(list)
for voice_sentence in voice_sentences:
sentence_list = [voice_sentence]
sentence_tmp_list = []
for punc in self.cut_punc:
for sentence in sentence_list:
sentence_tmp_list.extend(sentence.split(punc))
sentence_list, sentence_tmp_list = sentence_tmp_list, []
sentence_list = [item for item in sentence_list if item]
            # sentence units after splitting
voice_sentences_result["cut_sentences"].extend(sentence_list)
            # index of the sentence unit holding the last character of each voice unit
voice_sentences_result["cut_sentence_index_list"].append(len(voice_sentences_result["cut_sentences"]) - 1)
return voice_sentences_result
def _voice_bytes_index(self, timestamp, sample_rate=16000, bit_depth=16, channels=1):
"""
timestamp: 时间, 单位秒
"""
bytes_per_sample = bit_depth // 8
return timestamp * sample_rate * bytes_per_sample * channels
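    # Worked example (hypothetical values): timestamp 1.0 s at 16 kHz, 16-bit,
    # mono maps to 1.0 * 16000 * (16 // 8) * 1 = 32000.0 bytes into the stream.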
class AsrZhEvaluator(AsrEvaluator):
"""
中文的评估方式
"""
def __init__(self):
super().__init__()
        self.cut_zh_punc = ["······", "......", "。", ",", "?", "!", ";", ":"]
self.cut_en_punc = ["...", ".", ",", "?", "!", ";", ":"]
self.cut_punc = self.cut_zh_punc + self.cut_en_punc
    def evaluate(self, eval_result) -> None:
        self.query_count += 1
        self.voice_count += len(eval_result["voice"])
        # labels: voice units (not sentence units)
        label_voice_sentences = [item["answer"] for item in eval_result["voice"]]
        # labels: how voice units convert into sentence units
        voice_to_cut_info = self._voice_to_cut_sentence(label_voice_sentences)
        # labels: sentence units
        label_sentences = voice_to_cut_info["cut_sentences"]
        # labels: voice-unit -> sentence-unit mapping, i.e. the index of the
        # sentence unit holding the last character of each voice unit
        cut_sentence_index_list = voice_to_cut_info["cut_sentence_index_list"]
        # labels: normalize each sentence unit
        label_sentences = [self._sentence_norm(sentence) for sentence in label_sentences]
        if log_mid_result:
            logger.info(f"label_sentences {label_sentences}")
        # predictions: sentence units
        predict_sentences_raw = self._get_predict_final_sentences(eval_result["predict_data"])
        # predictions: normalize each sentence unit
        predict_sentences = [self._sentence_norm(sentence) for sentence in predict_sentences_raw]
        if log_mid_result:
            logger.info(f"predict_sentences {predict_sentences}")
        # match tokens via minimum edit distance; returns the aligned token lists
        label_tokens, predict_tokens = self._sentence_transfer("".join(label_sentences), "".join(predict_sentences))
        # CER computation
        cer_info = self.cer(label_sentences, predict_sentences)
        if log_mid_result:
            logger.info(f"cer_info {cer_info}")
        self.one_minus_cer += cer_info["one_minus_cer"]
        self.token_count += cer_info["token_count"]
        # sentence-split rate (precision/recall style)
        cut_info = self.cut_rate(label_sentences, predict_sentences, label_tokens, predict_tokens)
        if log_mid_result:
            logger.info(f"{cut_info['miss_count']}, {cut_info['more_count']}, {cut_info['rate']}")
        self.miss_count += cut_info["miss_count"]
        self.more_count += cut_info["more_count"]
        self.cut_count += cut_info["cut_count"]
        self.rate += cut_info["rate"]
self.result.append(
{
"label_tokens": label_tokens,
"predict_tokens": predict_tokens,
"one_minus_cer": cer_info["one_minus_cer"],
"token_count": cer_info["one_minus_cer"],
"miss_count": cut_info["miss_count"],
"more_count": cut_info["more_count"],
"cut_count": cut_info["cut_count"],
"rate": cut_info["rate"],
}
)
def cer(self, label_sentences, predict_sentences):
pred_str = ''.join(predict_sentences) if predict_sentences is not None else ''
label_str = ''.join(label_sentences)
r = asr_ter.calc_ter_speechio(pred_str, label_str)
one_minus_cer = max(1.0 - r['ter'], 0)
token_count = r['ref_all_token_cnt']
return {"one_minus_cer": one_minus_cer, "token_count": token_count}
def cut_rate(self, label_sentences, predict_sentences, label_tokens, predict_tokens):
label_final_index_list = set(self._sentence_final_index(label_sentences, label_tokens))
pred_final_index_list = set(self._sentence_final_index(predict_sentences, predict_tokens))
label_sentence_count = len(label_final_index_list)
miss_count = len(label_final_index_list - pred_final_index_list)
more_count = len(pred_final_index_list - label_final_index_list)
rate = max(1 - (miss_count + more_count * 2) / label_sentence_count, 0)
return {
"miss_count": miss_count,
"more_count": more_count,
"cut_count": label_sentence_count,
"rate": rate,
"label_final_index_list": label_final_index_list,
"pred_final_index_list": pred_final_index_list,
}
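    # Worked example (hypothetical indices): label final indices {2, 5, 9} and
    # predicted final indices {2, 5, 9, 11} give miss_count = 0, more_count = 1,
    # cut_count = 3, and rate = max(1 - (0 + 1 * 2) / 3, 0) = 1 / 3.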
def _sentence_norm(self, sentence, tokenizer="word"):
"""
对句子进行 norm 操作
"""
from utils.speechio import textnorm_zh as textnorm
if tokenizer == "word":
normalizer = textnorm.TextNorm(
to_banjiao=True,
to_upper=True,
to_lower=False,
remove_fillers=True,
                remove_erhua=False,  # unlike the batch-recognition setup, this is set to False here
check_chars=False,
remove_space=False,
cc_mode="",
)
return normalizer(sentence)
else:
logger.error("tokenizer error, not support.")
def _sentence_transfer(self, label_sentence: str, predict_sentence: str, tokenizer="char"):
"""
基于最小编辑距离,将 label 和 predict 进行字的位置匹配,并生成转换后的结果
args:
label: "今天的通话质量不错呀昨天的呢"
predict: "今天的通话质量不错昨天呢星期"
tokenizer: 分词方式
return:
label: ["", "", "", "", "", "", "", "", "", "", "", "", "", "", None, None]
predict: ["", "", "", "", "", "", "", "", "", None, "", "", None, "", "", ""]
"""
from utils.speechio import error_rate_zh as error_rate
if tokenizer == "char":
alignment, score = error_rate.EditDistance(
error_rate.tokenize_text(label_sentence, tokenizer),
error_rate.tokenize_text(predict_sentence, tokenizer),
)
label_tokens, pred_tokens = [], []
for align in alignment:
label_tokens.append(align.ref)
pred_tokens.append(align.hyp)
return (label_tokens, pred_tokens)
else:
logger.error("tokenizer 出错了,暂时不支持其它的")
def _pred_data_transfer(self, predict_data, recv_time):
"""
predict_data = [
{"recoginition_results": {"text": "1", "final_result": False, "para_seq": 0}},
{"recoginition_results": {"text": "12", "final_result": False, "para_seq": 0}},
{"recoginition_results": {"text": "123", "final_result": True, "para_seq": 0}},
{"recoginition_results": {"text": "4", "final_result": False, "para_seq": 0}},
{"recoginition_results": {"text": "45", "final_result": False, "para_seq": 0}},
{"recoginition_results": {"text": "456", "final_result": True, "para_seq": 0}},
]
recv_time = [1, 3, 5, 6, 7, 8]
->
[
[{'text': '1', 'time': 1}, {'text': '12', 'time': 3}, {'text': '123', 'time': 5}],
[{'text': '4', 'time': 6}, {'text': '45', 'time': 7}, {'text': '456', 'time': 8}],
]
"""
pred_sentence_info = []
pred_sentence_index = 0
for predict_item, recv_time_item in zip(predict_data, recv_time):
if len(pred_sentence_info) == pred_sentence_index:
pred_sentence_info.append([])
pred_sentence_info[pred_sentence_index].append(
{
"text": predict_item["recoginition_results"]["text"],
"time": recv_time_item,
}
)
if predict_item["recoginition_results"]["final_result"]:
pred_sentence_index += 1
return pred_sentence_info
class AsrEnEvaluator(AsrEvaluator):
"""
英文的评估方式
"""
    def evaluate(self, eval_result) -> None:
        self.query_count += 1
        self.voice_count += len(eval_result["voice"])
        # labels: voice units (not sentence units)
        label_voice_sentences = [item["answer"] for item in eval_result["voice"]]
        # labels: how voice units convert into sentence units
        voice_to_cut_info = self._voice_to_cut_sentence(label_voice_sentences)
        # labels: sentence units
        label_sentences = voice_to_cut_info["cut_sentences"]
        # labels: voice-unit -> sentence-unit mapping, i.e. the index of the
        # sentence unit holding the last word of each voice unit
        cut_sentence_index_list = voice_to_cut_info["cut_sentence_index_list"]
        # labels: normalize the sentence units
        label_sentences = self._sentence_list_norm(label_sentences)
        if log_mid_result:
            logger.info(f"label_sentences {label_sentences}")
        # predictions: sentence units
        predict_sentences_raw = self._get_predict_final_sentences(eval_result["predict_data"])
        # predictions: normalize the sentence units
        predict_sentences = self._sentence_list_norm(predict_sentences_raw)
        if log_mid_result:
            logger.info(f"predict_sentences {predict_sentences}")
        label_tokens, predict_tokens = self._sentence_transfer(" ".join(label_sentences), " ".join(predict_sentences))
        # CER computation (word level for English)
        cer_info = self.cer(label_tokens, predict_tokens)
        if log_mid_result:
            logger.info(f"cer_info {cer_info}")
        self.one_minus_cer += cer_info["one_minus_cer"]
        self.token_count += cer_info["token_count"]
        # sentence-split rate (precision/recall style)
        cut_info = self.cut_rate(label_sentences, predict_sentences, label_tokens, predict_tokens)
        if log_mid_result:
            logger.info(f"{cut_info['miss_count']}, {cut_info['more_count']}, {cut_info['rate']}")
self.miss_count += cut_info["miss_count"]
self.more_count += cut_info["more_count"]
self.cut_count += cut_info["cut_count"]
self.rate += cut_info["rate"]
self.result.append(
{
"label_tokens": label_tokens,
"predict_tokens": predict_tokens,
"one_minus_cer": cer_info["one_minus_cer"],
"token_count": cer_info["one_minus_cer"],
"miss_count": cut_info["miss_count"],
"more_count": cut_info["more_count"],
"cut_count": cut_info["cut_count"],
"rate": cut_info["rate"],
}
)
def cer(self, label_tokens, predict_tokens):
s, d, i, c = 0, 0, 0, 0
for label_token, predict_token in zip(label_tokens, predict_tokens):
if label_token == predict_token:
c += 1
elif predict_token is None:
d += 1
elif label_token is None:
i += 1
else:
s += 1
cer = (s + d + i) / (s + d + c)
one_minus_cer = max(1.0 - cer, 0)
token_count = s + d + c
return {"one_minus_cer": one_minus_cer, "token_count": token_count}
def cut_rate(self, label_sentences, predict_sentences, label_tokens, predict_tokens):
label_final_index_list = set(self._sentence_final_index(label_sentences, label_tokens, "whitespace"))
pred_final_index_list = set(self._sentence_final_index(predict_sentences, predict_tokens, "whitespace"))
label_sentence_count = len(label_final_index_list)
miss_count = len(label_final_index_list - pred_final_index_list)
more_count = len(pred_final_index_list - label_final_index_list)
rate = max(1 - (miss_count + more_count * 2) / label_sentence_count, 0)
return {
"miss_count": miss_count,
"more_count": more_count,
"cut_count": label_sentence_count,
"rate": rate,
"label_final_index_list": label_final_index_list,
"pred_final_index_list": pred_final_index_list,
}
def _sentence_list_norm(self, sentence_list, tokenizer="whitespace"):
pwd = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
with open('./predict.txt', 'w', encoding='utf-8') as fp:
for idx, sentence in enumerate(sentence_list):
fp.write('%s\t%s\n' % (idx, sentence))
subprocess.run(
f'PYTHONPATH={pwd}/utils/speechio python {pwd}/utils/speechio/textnorm_en.py --has_key --to_upper ./predict.txt ./predict_norm.txt',
shell=True,
check=True,
)
sentence_norm = []
with open('./predict_norm.txt', 'r', encoding='utf-8') as fp:
for line in fp.readlines():
line_split_result = line.strip().split('\t', 1)
if len(line_split_result) >= 2:
sentence_norm.append(line_split_result[1])
        # a line may have nothing left after normalization; such lines are skipped
return sentence_norm
def _sentence_transfer(self, label_sentence: str, predict_sentence: str, tokenizer="whitespace"):
"""
基于最小编辑距离,将 label 和 predict 进行字的位置匹配,并生成转换后的结果
args:
label: "HELLO WORLD ARE U OK YEP"
predict: "HELLO WORLD U ARE U OK YEP"
tokenizer: 分词方式
return:
label: ["HELLO", "WORLD", None, "ARE", "U", "OK", "YEP"]
predict: ["HELLO", "WORLD", "U", "ARE", "U", "OK", "YEP"]
"""
from utils.speechio import error_rate_zh as error_rate
if tokenizer == "whitespace":
alignment, score = error_rate.EditDistance(
error_rate.tokenize_text(label_sentence, tokenizer),
error_rate.tokenize_text(predict_sentence, tokenizer),
)
label_tokens, pred_tokens = [], []
for align in alignment:
label_tokens.append(align.ref)
pred_tokens.append(align.hyp)
return (label_tokens, pred_tokens)
else:
logger.error("tokenizer 出错了,暂时不支持其它的")
    def post_evaluate(self) -> None:
        pass
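

# Minimal usage sketch (not part of the original file; assumes the repo's
# utils/ package -- asr_ter and the speechio text-norm helpers -- is importable).
# It feeds one hand-built eval_result, shaped the way evaluate() expects,
# through AsrZhEvaluator and prints the aggregated metrics from gen_result().
if __name__ == "__main__":
    evaluator = AsrZhEvaluator()
    eval_result = {
        # each voice unit carries its reference transcript under "answer"
        "voice": [{"answer": "你好,你好呀"}],
        # streaming partials; only entries with final_result=True are scored
        "predict_data": [
            {"recoginition_results": {"text": "你好", "final_result": False, "para_seq": 0}},
            {"recoginition_results": {"text": "你好,你好呀", "final_result": True, "para_seq": 0}},
        ],
    }
    evaluator.evaluate(eval_result)
    output_result, detail_case = evaluator.gen_result()
    print(output_result)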