utils/evaluate.py (new file, 445 lines)
import os
import subprocess
from collections import defaultdict
from typing import Dict, List, Tuple

from utils import asr_ter
from utils.logger import logger

# NOTE: `Tokenizer` (used by `_sentence_final_index` below) is not imported in
# this file; it is assumed to be provided by a project-local module.

log_mid_result = int(os.getenv("log", 0)) == 1


class AsrEvaluator:
    def __init__(self) -> None:
        self.query_count = 0  # number of queries (audio clips)
        self.voice_count = 0  # number of clips with start and end times (used for RTF)
        self.cut_punc = []  # split punctuation; split in list order, e.g. "..." must come before "."
        # CER accumulators
        self.one_minus_cer = 0  # sum of per-query (1 - CER)
        self.token_count = 0  # sum of per-query character/word counts
        # sentence-segmentation accumulators
        self.miss_count = 0  # sum of per-query miss-counts
        self.more_count = 0  # sum of per-query more-counts
        self.cut_count = 0  # sum of per-query cut-counts
        self.rate = 0  # sum of per-query cut-rates
        # RTF accumulators read by gen_result(); initializing them to 0 here is
        # an assumption, since no other code in this file sets them
        self.rtf = 0
        self.rtf_end = 0
        # per-query detail cases
        self.result = []

    def evaluate(self, eval_result):
        pass

    def post_evaluate(self):
        pass

    def gen_result(self) -> Tuple[Dict, List]:
        output_result = dict()
        output_result["query_count"] = self.query_count
        output_result["voice_count"] = self.voice_count
        output_result["token_cnt"] = self.token_count
        output_result["one_minus_cer"] = self.one_minus_cer
        output_result["one_minus_cer_metrics"] = self.one_minus_cer / self.query_count
        output_result["miss_count"] = self.miss_count
        output_result["more_count"] = self.more_count
        output_result["cut_count"] = self.cut_count
        output_result["cut_rate"] = self.rate
        output_result["cut_rate_metrics"] = self.rate / self.query_count
        output_result["rtf"] = self.rtf
        output_result["rtf_end"] = self.rtf_end
        output_result["rtf_metrics"] = self.rtf / self.voice_count
        output_result["rtf_end_metrics"] = self.rtf_end / self.voice_count

        detail_case = self.result
        return output_result, detail_case

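    # Illustrative usage sketch (not part of the original file; `dataset` is a
    # hypothetical iterable shaped like what evaluate() consumes):
    #
    #     evaluator = AsrZhEvaluator()
    #     for eval_result in dataset:  # {"voice": [...], "predict_data": [...]}
    #         evaluator.evaluate(eval_result)
    #     metrics, detail_case = evaluator.gen_result()
    #     print(metrics["one_minus_cer_metrics"], metrics["cut_rate_metrics"])
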
    def _get_predict_final_sentences(self, predict_data: List[Dict]) -> List[str]:
        """
        From the predict data, collect the sentences marked final into a list.
        """
        # "recoginition_results" (sic) is the key spelling used by the payload
        return [
            item["recoginition_results"]["text"]
            for item in predict_data
            if item["recoginition_results"]["final_result"]
        ]

    def _sentence_final_index(self, sentences: List[str], tokens: List[str], tokenizer="word") -> List[int]:
        """
        For each sentence, get the index of the token that ends it.
        """
        token_index_list = []
        token_idx = 0
        for sentence in sentences:
            for token in Tokenizer.tokenize(sentence, tokenizer):
                if token not in tokens:
                    continue
                while tokens[token_idx] != token:
                    token_idx += 1
            token_index_list.append(token_idx)
        return token_index_list

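    # Worked example (assuming Tokenizer.tokenize yields single characters for
    # Chinese "word" mode): with tokens ["你", "好", "你", "好", "呀"] and
    # sentences ["你好", "你好呀"], the scan advances to index 1 for "你好" and
    # to index 4 for "你好呀", so the result is [1, 4]: the token index of each
    # sentence's final character.
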
    def _voice_to_cut_sentence(self, voice_sentences: List[str]) -> Dict:
        """
        Convert the dataset's voice segments into a list of minimal split units.

        Split with each punctuation mark in cut_punc in turn, then drop units
        that are completely empty.

        Example:
            ["你好,你好呀", "你好,我在写抽象的代码逻辑"]
            ->
            cut_sentences: ["你好", "你好呀", "你好", "我在写抽象的代码逻辑"]
            cut_sentence_index_list: [1, 3]  ("你好呀" ends at index 1, "我在写抽象的代码逻辑" at index 3)
        """
        voice_sentences_result = defaultdict(list)
        for voice_sentence in voice_sentences:
            sentence_list = [voice_sentence]
            sentence_tmp_list = []
            for punc in self.cut_punc:
                for sentence in sentence_list:
                    sentence_tmp_list.extend(sentence.split(punc))
                sentence_list, sentence_tmp_list = sentence_tmp_list, []
            sentence_list = [item for item in sentence_list if item]
            # the split sentence units
            voice_sentences_result["cut_sentences"].extend(sentence_list)
            # index of the sentence unit holding the last character of this voice unit
            voice_sentences_result["cut_sentence_index_list"].append(len(voice_sentences_result["cut_sentences"]) - 1)
        return voice_sentences_result

    def _voice_bytes_index(self, timestamp, sample_rate=16000, bit_depth=16, channels=1):
        """
        timestamp: time in seconds; returns the matching byte offset in the PCM stream.
        """
        bytes_per_sample = bit_depth // 8
        return timestamp * sample_rate * bytes_per_sample * channels

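    # Worked example: at 16 kHz, 16-bit (2 bytes/sample), mono audio the stream
    # carries 16000 * 2 * 1 = 32000 bytes per second, so a timestamp of 1.5 s
    # maps to byte offset 1.5 * 32000 = 48000.0 (callers may need to round to
    # an integer sample boundary).

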
class AsrZhEvaluator(AsrEvaluator):
    """
    Evaluation logic for Chinese.
    """

    def __init__(self):
        super().__init__()
        self.cut_zh_punc = ["······", "......", "。", ",", "?", "!", ";", ":"]
        self.cut_en_punc = ["...", ".", ",", "?", "!", ";", ":"]
        self.cut_punc = self.cut_zh_punc + self.cut_en_punc

    def evaluate(self, eval_result) -> None:
        self.query_count += 1
        self.voice_count += len(eval_result["voice"])

        # labels: voice units (not yet sentence units)
        label_voice_sentences = [item["answer"] for item in eval_result["voice"]]
        # labels: how voice units convert into sentence units
        voice_to_cut_info = self._voice_to_cut_sentence(label_voice_sentences)
        # labels: sentence units
        label_sentences = voice_to_cut_info["cut_sentences"]
        # labels: voice-unit -> sentence-unit mapping, i.e. for each voice unit
        # the index of the sentence unit holding its last character
        cut_sentence_index_list = voice_to_cut_info["cut_sentence_index_list"]
        # labels: normalize each sentence unit
        label_sentences = [self._sentence_norm(sentence) for sentence in label_sentences]
        if log_mid_result:
            logger.info(f"label_sentences {label_sentences}")

        # predictions: sentence units
        predict_sentences_raw = self._get_predict_final_sentences(eval_result["predict_data"])
        # predictions: normalize each sentence unit
        predict_sentences = [self._sentence_norm(sentence) for sentence in predict_sentences_raw]
        if log_mid_result:
            logger.info(f"predict_sentences {predict_sentences}")

        # match tokens by minimum edit distance to get the aligned token lists
        label_tokens, predict_tokens = self._sentence_transfer("".join(label_sentences), "".join(predict_sentences))

        # CER
        cer_info = self.cer(label_sentences, predict_sentences)
        if log_mid_result:
            logger.info(f"cer_info {cer_info}")
        self.one_minus_cer += cer_info["one_minus_cer"]
        self.token_count += cer_info["token_count"]

        # sentence-segmentation precision/recall
        cut_info = self.cut_rate(label_sentences, predict_sentences, label_tokens, predict_tokens)
        if log_mid_result:
            logger.info(f"{cut_info['miss_count']}, {cut_info['more_count']}, {cut_info['rate']}")
        self.miss_count += cut_info["miss_count"]
        self.more_count += cut_info["more_count"]
        self.cut_count += cut_info["cut_count"]
        self.rate += cut_info["rate"]

        self.result.append(
            {
                "label_tokens": label_tokens,
                "predict_tokens": predict_tokens,
                "one_minus_cer": cer_info["one_minus_cer"],
                "token_count": cer_info["token_count"],
                "miss_count": cut_info["miss_count"],
                "more_count": cut_info["more_count"],
                "cut_count": cut_info["cut_count"],
                "rate": cut_info["rate"],
            }
        )

    def cer(self, label_sentences, predict_sentences):
        pred_str = ''.join(predict_sentences) if predict_sentences is not None else ''
        label_str = ''.join(label_sentences)
        r = asr_ter.calc_ter_speechio(pred_str, label_str)
        one_minus_cer = max(1.0 - r['ter'], 0)
        token_count = r['ref_all_token_cnt']
        return {"one_minus_cer": one_minus_cer, "token_count": token_count}

    def cut_rate(self, label_sentences, predict_sentences, label_tokens, predict_tokens):
        label_final_index_list = set(self._sentence_final_index(label_sentences, label_tokens))
        pred_final_index_list = set(self._sentence_final_index(predict_sentences, predict_tokens))
        label_sentence_count = len(label_final_index_list)
        miss_count = len(label_final_index_list - pred_final_index_list)
        more_count = len(pred_final_index_list - label_final_index_list)
        rate = max(1 - (miss_count + more_count * 2) / label_sentence_count, 0)
        return {
            "miss_count": miss_count,
            "more_count": more_count,
            "cut_count": label_sentence_count,
            "rate": rate,
            "label_final_index_list": label_final_index_list,
            "pred_final_index_list": pred_final_index_list,
        }

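    # Worked example of the rate formula above: with 4 labeled sentence-final
    # positions, 1 missed boundary and 1 spurious boundary give
    # rate = max(1 - (1 + 2 * 1) / 4, 0) = 0.25; spurious boundaries are
    # penalized twice as heavily as missed ones.
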
    def _sentence_norm(self, sentence, tokenizer="word"):
        """
        Normalize a sentence.
        """
        from utils.speechio import textnorm_zh as textnorm

        if tokenizer == "word":
            normalizer = textnorm.TextNorm(
                to_banjiao=True,
                to_upper=True,
                to_lower=False,
                remove_fillers=True,
                remove_erhua=False,  # unlike batch recognition, this is False here
                check_chars=False,
                remove_space=False,
                cc_mode="",
            )
            return normalizer(sentence)
        else:
            logger.error("tokenizer error: mode not supported.")

    def _sentence_transfer(self, label_sentence: str, predict_sentence: str, tokenizer="char"):
        """
        Align label and predict character positions via minimum edit distance
        and return the transformed (aligned) token lists.

        args:
            label: "今天的通话质量不错呀昨天的呢"
            predict: "今天的通话质量不错昨天呢星期"
            tokenizer: tokenization mode
        return:
            label: ["今", "天", "的", "通", "话", "质", "量", "不", "错", "呀", "昨", "天", "的", "呢", None, None]
            predict: ["今", "天", "的", "通", "话", "质", "量", "不", "错", None, "昨", "天", None, "呢", "星", "期"]
        """
        from utils.speechio import error_rate_zh as error_rate

        if tokenizer == "char":
            alignment, score = error_rate.EditDistance(
                error_rate.tokenize_text(label_sentence, tokenizer),
                error_rate.tokenize_text(predict_sentence, tokenizer),
            )
            label_tokens, pred_tokens = [], []
            for align in alignment:
                label_tokens.append(align.ref)
                pred_tokens.append(align.hyp)
            return (label_tokens, pred_tokens)
        else:
            logger.error("tokenizer error: other modes are not supported yet.")

    def _pred_data_transfer(self, predict_data, recv_time):
        """
        Group streaming partial results into one list per final sentence,
        attaching the receive time of each partial result.

        predict_data = [
            {"recoginition_results": {"text": "1", "final_result": False, "para_seq": 0}},
            {"recoginition_results": {"text": "12", "final_result": False, "para_seq": 0}},
            {"recoginition_results": {"text": "123", "final_result": True, "para_seq": 0}},
            {"recoginition_results": {"text": "4", "final_result": False, "para_seq": 0}},
            {"recoginition_results": {"text": "45", "final_result": False, "para_seq": 0}},
            {"recoginition_results": {"text": "456", "final_result": True, "para_seq": 0}},
        ]
        recv_time = [1, 3, 5, 6, 7, 8]

        ->

        [
            [{'text': '1', 'time': 1}, {'text': '12', 'time': 3}, {'text': '123', 'time': 5}],
            [{'text': '4', 'time': 6}, {'text': '45', 'time': 7}, {'text': '456', 'time': 8}],
        ]
        """
        pred_sentence_info = []
        pred_sentence_index = 0
        for predict_item, recv_time_item in zip(predict_data, recv_time):
            if len(pred_sentence_info) == pred_sentence_index:
                pred_sentence_info.append([])
            pred_sentence_info[pred_sentence_index].append(
                {
                    "text": predict_item["recoginition_results"]["text"],
                    "time": recv_time_item,
                }
            )
            if predict_item["recoginition_results"]["final_result"]:
                pred_sentence_index += 1
        return pred_sentence_info


class AsrEnEvaluator(AsrEvaluator):
    """
    Evaluation logic for English.

    Note: there is no __init__ override, so cut_punc stays the empty list
    inherited from AsrEvaluator and _voice_to_cut_sentence keeps each voice
    unit whole.
    """

    def evaluate(self, eval_result) -> None:
        self.query_count += 1
        self.voice_count += len(eval_result["voice"])

        # labels: voice units (not yet sentence units)
        label_voice_sentences = [item["answer"] for item in eval_result["voice"]]
        # labels: how voice units convert into sentence units
        voice_to_cut_info = self._voice_to_cut_sentence(label_voice_sentences)
        # labels: sentence units
        label_sentences = voice_to_cut_info["cut_sentences"]
        # labels: voice-unit -> sentence-unit mapping, i.e. for each voice unit
        # the index of the sentence unit holding its last word
        cut_sentence_index_list = voice_to_cut_info["cut_sentence_index_list"]
        # labels: normalize the sentence units
        label_sentences = self._sentence_list_norm(label_sentences)
        if log_mid_result:
            logger.info(f"label_sentences {label_sentences}")

        # predictions: sentence units
        predict_sentences_raw = self._get_predict_final_sentences(eval_result["predict_data"])
        # predictions: normalize the sentence units
        predict_sentences = self._sentence_list_norm(predict_sentences_raw)
        if log_mid_result:
            logger.info(f"predict_sentences {predict_sentences}")

        # match tokens by minimum edit distance to get the aligned token lists
        label_tokens, predict_tokens = self._sentence_transfer(" ".join(label_sentences), " ".join(predict_sentences))

        # word-level error rate over the aligned tokens (named cer for
        # interface symmetry with the Chinese evaluator)
        cer_info = self.cer(label_tokens, predict_tokens)
        if log_mid_result:
            logger.info(f"cer_info {cer_info}")
        self.one_minus_cer += cer_info["one_minus_cer"]
        self.token_count += cer_info["token_count"]

        # sentence-segmentation precision/recall
        cut_info = self.cut_rate(label_sentences, predict_sentences, label_tokens, predict_tokens)
        if log_mid_result:
            logger.info(f"{cut_info['miss_count']}, {cut_info['more_count']}, {cut_info['rate']}")
        self.miss_count += cut_info["miss_count"]
        self.more_count += cut_info["more_count"]
        self.cut_count += cut_info["cut_count"]
        self.rate += cut_info["rate"]

        self.result.append(
            {
                "label_tokens": label_tokens,
                "predict_tokens": predict_tokens,
                "one_minus_cer": cer_info["one_minus_cer"],
                "token_count": cer_info["token_count"],
                "miss_count": cut_info["miss_count"],
                "more_count": cut_info["more_count"],
                "cut_count": cut_info["cut_count"],
                "rate": cut_info["rate"],
            }
        )

    def cer(self, label_tokens, predict_tokens):
        # tally substitutions (s), deletions (d), insertions (i), and correct
        # tokens (c) over the aligned token pairs
        s, d, i, c = 0, 0, 0, 0
        for label_token, predict_token in zip(label_tokens, predict_tokens):
            if label_token == predict_token:
                c += 1
            elif predict_token is None:
                d += 1
            elif label_token is None:
                i += 1
            else:
                s += 1
        cer = (s + d + i) / (s + d + c)
        one_minus_cer = max(1.0 - cer, 0)
        token_count = s + d + c
        return {"one_minus_cer": one_minus_cer, "token_count": token_count}

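    # Worked example for the aligned-token error rate above: with
    #     label_tokens   = ["A", "B", None, "C"]
    #     predict_tokens = ["A", "X", "Y",  "C"]
    # we get c=2, s=1 ("B" -> "X"), i=1 (inserted "Y"), d=0, so
    # cer = (1 + 0 + 1) / (1 + 0 + 2) = 2/3 and one_minus_cer = 1/3;
    # token_count = s + d + c = 3 is the reference length.
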
    def cut_rate(self, label_sentences, predict_sentences, label_tokens, predict_tokens):
        label_final_index_list = set(self._sentence_final_index(label_sentences, label_tokens, "whitespace"))
        pred_final_index_list = set(self._sentence_final_index(predict_sentences, predict_tokens, "whitespace"))
        label_sentence_count = len(label_final_index_list)
        miss_count = len(label_final_index_list - pred_final_index_list)
        more_count = len(pred_final_index_list - label_final_index_list)
        rate = max(1 - (miss_count + more_count * 2) / label_sentence_count, 0)
        return {
            "miss_count": miss_count,
            "more_count": more_count,
            "cut_count": label_sentence_count,
            "rate": rate,
            "label_final_index_list": label_final_index_list,
            "pred_final_index_list": pred_final_index_list,
        }

    def _sentence_list_norm(self, sentence_list, tokenizer="whitespace"):
        pwd = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        with open('./predict.txt', 'w', encoding='utf-8') as fp:
            for idx, sentence in enumerate(sentence_list):
                fp.write('%s\t%s\n' % (idx, sentence))
        subprocess.run(
            f'PYTHONPATH={pwd}/utils/speechio python {pwd}/utils/speechio/textnorm_en.py --has_key --to_upper ./predict.txt ./predict_norm.txt',
            shell=True,
            check=True,
        )
        sentence_norm = []
        with open('./predict_norm.txt', 'r', encoding='utf-8') as fp:
            for line in fp.readlines():
                line_split_result = line.strip().split('\t', 1)
                # a line may have no text left after normalization; skip those
                if len(line_split_result) >= 2:
                    sentence_norm.append(line_split_result[1])
        return sentence_norm

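    # Sketch of the temp-file round trip above (the normalized output shown is
    # an assumption based on the --to_upper flag): ./predict.txt is written as
    # key<TAB>text lines such as
    #     0\thello world
    #     1\tit's ok
    # and ./predict_norm.txt is read back in the same key<TAB>text shape,
    # e.g. "0\tHELLO WORLD"; keys whose text normalizes away are dropped.
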
    def _sentence_transfer(self, label_sentence: str, predict_sentence: str, tokenizer="whitespace"):
        """
        Align label and predict word positions via minimum edit distance and
        return the transformed (aligned) token lists.

        args:
            label: "HELLO WORLD ARE U OK YEP"
            predict: "HELLO WORLD U ARE U OK YEP"
            tokenizer: tokenization mode
        return:
            label: ["HELLO", "WORLD", None, "ARE", "U", "OK", "YEP"]
            predict: ["HELLO", "WORLD", "U", "ARE", "U", "OK", "YEP"]
        """
        from utils.speechio import error_rate_zh as error_rate

        if tokenizer == "whitespace":
            alignment, score = error_rate.EditDistance(
                error_rate.tokenize_text(label_sentence, tokenizer),
                error_rate.tokenize_text(predict_sentence, tokenizer),
            )
            label_tokens, pred_tokens = [], []
            for align in alignment:
                label_tokens.append(align.ref)
                pred_tokens.append(align.hyp)
            return (label_tokens, pred_tokens)
        else:
            logger.error("tokenizer error: other modes are not supported yet.")

    def post_evaluate(self):
        pass