import os
|
||
import subprocess
|
||
from collections import defaultdict
|
||
from typing import Dict, List
|
||
|
||
from utils import asr_ter
|
||
from utils.logger import logger
|
||
|
||
# When the "log" env var is set to "1", intermediate evaluation results are
# logged. os.getenv conventionally takes a string default; int("0") == int(0),
# so behavior is unchanged.
log_mid_result = int(os.getenv("log", "0")) == 1
class AsrEvaluator:
    """Base class for streaming ASR evaluation.

    Subclasses implement ``evaluate`` to score one query at a time,
    accumulating counters on this object; ``gen_result`` then aggregates
    the counters into a metrics dict plus a per-query detail list.
    """

    def __init__(self) -> None:
        self.query_count = 0  # number of queries (audio clips) evaluated
        self.voice_count = 0  # clips with start/end times (used for RTF)
        # Sentence-splitting punctuation. Order matters: longer marks must
        # come before their prefixes, e.g. "..." before ".".
        self.cut_punc: List[str] = []
        # CER accumulators
        self.one_minus_cer = 0  # sum over queries of (1 - CER)
        self.token_count = 0  # sum over queries of reference token counts
        # Sentence-cut accumulators
        self.miss_count = 0  # sum of per-query missed cut points
        self.more_count = 0  # sum of per-query spurious cut points
        self.cut_count = 0  # sum of per-query reference cut points
        self.rate = 0  # sum of per-query cut rates
        # RTF accumulators. Initialized here so gen_result() cannot raise
        # AttributeError when nothing ever added to them (previously these
        # were only ever set outside __init__).
        self.rtf = 0
        self.rtf_end = 0
        # Per-query detail records appended by subclasses' evaluate()
        self.result: List[Dict] = []

    def evaluate(self, eval_result):
        """Score one query's result; implemented by subclasses."""
        pass

    def post_evaluate(self):
        """Optional hook run after all queries have been evaluated."""
        pass

    def gen_result(self):
        """Aggregate accumulated counters.

        Returns:
            Tuple of (metrics dict, per-query detail list). Ratio metrics
            fall back to 0 instead of raising ZeroDivisionError when the
            corresponding denominator is 0 (no queries / no timed voices).
        """
        queries = self.query_count
        voices = self.voice_count
        output_result = {
            "query_count": queries,
            "voice_count": voices,
            "token_cnt": self.token_count,
            "one_minus_cer": self.one_minus_cer,
            "one_minus_cer_metrics": self.one_minus_cer / queries if queries else 0,
            "miss_count": self.miss_count,
            "more_count": self.more_count,
            "cut_count": self.cut_count,
            "cut_rate": self.rate,
            "cut_rate_metrics": self.rate / queries if queries else 0,
            "rtf": self.rtf,
            "rtf_end": self.rtf_end,
            "rtf_metrics": self.rtf / voices if voices else 0,
            "rtf_end_metrics": self.rtf_end / voices if voices else 0,
        }
        detail_case = self.result
        return output_result, detail_case

    def _get_predict_final_sentences(self, predict_data: List[Dict]) -> List[str]:
        """Extract the text of every final (non-partial) recognition result.

        Note: "recoginition_results" is the upstream payload's spelling;
        do not "fix" it here.
        """
        return [
            item["recoginition_results"]["text"]
            for item in predict_data
            if item["recoginition_results"]["final_result"]
        ]

    def _sentence_final_index(self, sentences: List[str], tokens: List[str], tokenizer="word") -> List[int]:
        """Return, per sentence, the index in ``tokens`` of its last token.

        ``tokens`` is an edit-distance-aligned token list (may contain None
        for gaps); tokens of each sentence are assumed to appear in order.

        NOTE(review): ``Tokenizer`` is not imported in this module — confirm
        it is provided elsewhere at runtime, otherwise this raises NameError.
        """
        token_index_list = []
        token_idx = 0
        known_tokens = set(tokens)  # O(1) membership instead of list scans
        for sentence in sentences:
            for token in Tokenizer.tokenize(sentence, tokenizer):
                if token not in known_tokens:
                    # Token has no counterpart in the aligned list; skip it.
                    continue
                # Advance to this token's position in the aligned list.
                # May IndexError if tokens are out of order — unchanged from
                # the original behavior.
                while tokens[token_idx] != token:
                    token_idx += 1
            token_index_list.append(token_idx)
        return token_index_list

    def _voice_to_cut_sentence(self, voice_sentences: List[str]) -> Dict:
        """Split voice segments into minimal sentence units.

        Each segment is split successively on every mark in ``cut_punc``
        (list order matters), then empty pieces are dropped.

        Example:
            ["你好,你好呀", "你好,我在写抽象的代码逻辑"]
            ->
            cut_sentences: ["你好", "你好呀", "你好", "我在写抽象的代码逻辑"]
            cut_sentence_index_list: [1, 3]
            (index of the last sentence unit of each voice segment)
        """
        voice_sentences_result = defaultdict(list)
        for voice_sentence in voice_sentences:
            sentence_list = [voice_sentence]
            sentence_tmp_list = []
            for punc in self.cut_punc:
                for sentence in sentence_list:
                    sentence_tmp_list.extend(sentence.split(punc))
                sentence_list, sentence_tmp_list = sentence_tmp_list, []
            sentence_list = [item for item in sentence_list if item]
            # Sentence units after splitting
            voice_sentences_result["cut_sentences"].extend(sentence_list)
            # Index of the sentence unit containing this segment's last token
            voice_sentences_result["cut_sentence_index_list"].append(len(voice_sentences_result["cut_sentences"]) - 1)
        return voice_sentences_result

    def _voice_bytes_index(self, timestamp, sample_rate=16000, bit_depth=16, channels=1):
        """Convert a timestamp (in seconds) to a PCM byte offset.

        Defaults assume 16 kHz, 16-bit, mono audio.
        """
        bytes_per_sample = bit_depth // 8
        return timestamp * sample_rate * bytes_per_sample * channels
class AsrZhEvaluator(AsrEvaluator):
    """Evaluation pipeline for Chinese ASR results.

    Uses character-level edit-distance alignment for CER and
    punctuation-based sentence splitting for the cut-rate metric.
    """

    def __init__(self):
        super().__init__()
        # Longer marks first so e.g. "......" is consumed before ".".
        self.cut_zh_punc = ["······", "......", "。", ",", "?", "!", ";", ":"]
        self.cut_en_punc = ["...", ".", ",", "?", "!", ";", ":"]
        self.cut_punc = self.cut_zh_punc + self.cut_en_punc

    def evaluate(self, eval_result) -> Dict:
        """Score one query, accumulating CER / cut-rate counters.

        Args:
            eval_result: dict with "voice" (list of {"answer": ...} labeled
                segments) and "predict_data" (streaming recognition items).
        """
        self.query_count += 1
        self.voice_count += len(eval_result["voice"])

        # Reference transcripts, one per voice segment (not sentence units).
        label_voice_sentences = [item["answer"] for item in eval_result["voice"]]
        # Split the voice segments into minimal sentence units.
        voice_to_cut_info = self._voice_to_cut_sentence(label_voice_sentences)
        # Normalized reference sentence units.
        label_sentences = [self._sentence_norm(sentence) for sentence in voice_to_cut_info["cut_sentences"]]
        if log_mid_result:
            logger.info(f"label_sentences {label_sentences}")

        # Final (non-partial) predicted sentences, normalized.
        predict_sentences_raw = self._get_predict_final_sentences(eval_result["predict_data"])
        predict_sentences = [self._sentence_norm(sentence) for sentence in predict_sentences_raw]
        if log_mid_result:
            logger.info(f"predict_sentences {predict_sentences}")

        # Character-level alignment via minimum edit distance.
        label_tokens, predict_tokens = self._sentence_transfer("".join(label_sentences), "".join(predict_sentences))

        # CER
        cer_info = self.cer(label_sentences, predict_sentences)
        if log_mid_result:
            logger.info(f"cer_info {cer_info}")
        self.one_minus_cer += cer_info["one_minus_cer"]
        self.token_count += cer_info["token_count"]

        # Sentence-cut precision/recall
        cut_info = self.cut_rate(label_sentences, predict_sentences, label_tokens, predict_tokens)
        if log_mid_result:
            logger.info(f"{cut_info['miss_count']}, {cut_info['more_count']}, {cut_info['rate']}")
        self.miss_count += cut_info["miss_count"]
        self.more_count += cut_info["more_count"]
        self.cut_count += cut_info["cut_count"]
        self.rate += cut_info["rate"]

        self.result.append(
            {
                "label_tokens": label_tokens,
                "predict_tokens": predict_tokens,
                "one_minus_cer": cer_info["one_minus_cer"],
                # Fixed: this previously stored cer_info["one_minus_cer"].
                "token_count": cer_info["token_count"],
                "miss_count": cut_info["miss_count"],
                "more_count": cut_info["more_count"],
                "cut_count": cut_info["cut_count"],
                "rate": cut_info["rate"],
            }
        )

    def cer(self, label_sentences, predict_sentences):
        """Compute (1 - CER) via the speechio TER tool on joined sentences."""
        pred_str = ''.join(predict_sentences) if predict_sentences is not None else ''
        label_str = ''.join(label_sentences)
        r = asr_ter.calc_ter_speechio(pred_str, label_str)
        one_minus_cer = max(1.0 - r['ter'], 0)
        token_count = r['ref_all_token_cnt']
        return {"one_minus_cer": one_minus_cer, "token_count": token_count}

    def cut_rate(self, label_sentences, predict_sentences, label_tokens, predict_tokens):
        """Sentence-cut quality: missed cuts, spurious cuts, penalty rate.

        A spurious cut is penalized twice as heavily as a missed one; the
        rate is clamped at 0. With an empty reference the rate is 0
        (previously this raised ZeroDivisionError).
        """
        label_final_index_list = set(self._sentence_final_index(label_sentences, label_tokens))
        pred_final_index_list = set(self._sentence_final_index(predict_sentences, predict_tokens))
        label_sentence_count = len(label_final_index_list)
        miss_count = len(label_final_index_list - pred_final_index_list)
        more_count = len(pred_final_index_list - label_final_index_list)
        if label_sentence_count:
            rate = max(1 - (miss_count + more_count * 2) / label_sentence_count, 0)
        else:
            rate = 0
        return {
            "miss_count": miss_count,
            "more_count": more_count,
            "cut_count": label_sentence_count,
            "rate": rate,
            "label_final_index_list": label_final_index_list,
            "pred_final_index_list": pred_final_index_list,
        }

    def _sentence_norm(self, sentence, tokenizer="word"):
        """Normalize one sentence with the speechio Chinese text normalizer.

        Returns the normalized sentence, or None (after logging an error)
        for an unsupported tokenizer.
        """
        from utils.speechio import textnorm_zh as textnorm

        if tokenizer == "word":
            normalizer = textnorm.TextNorm(
                to_banjiao=True,
                to_upper=True,
                to_lower=False,
                remove_fillers=True,
                remove_erhua=False,  # differs from batch recognition: kept False here
                check_chars=False,
                remove_space=False,
                cc_mode="",
            )
            return normalizer(sentence)
        else:
            logger.error("tokenizer error, not support.")

    def _sentence_transfer(self, label_sentence: str, predict_sentence: str, tokenizer="char"):
        """Align label and prediction character-wise by minimum edit distance.

        Args:
            label_sentence: e.g. "今天的通话质量不错呀昨天的呢"
            predict_sentence: e.g. "今天的通话质量不错昨天呢星期"
            tokenizer: tokenization mode (only "char" supported)

        Returns:
            (label_tokens, pred_tokens) — equal-length lists where gaps in
            the alignment are None, e.g.:
            label:   ["今","天","的","通","话","质","量","不","错","呀","昨","天","的","呢",None,None]
            predict: ["今","天","的","通","话","质","量","不","错",None,"昨","天",None,"呢","星","期"]
            Returns None (after logging) for an unsupported tokenizer.
        """
        from utils.speechio import error_rate_zh as error_rate

        if tokenizer == "char":
            alignment, score = error_rate.EditDistance(
                error_rate.tokenize_text(label_sentence, tokenizer),
                error_rate.tokenize_text(predict_sentence, tokenizer),
            )
            label_tokens, pred_tokens = [], []
            for align in alignment:
                label_tokens.append(align.ref)
                pred_tokens.append(align.hyp)
            return (label_tokens, pred_tokens)
        else:
            logger.error("tokenizer 出错了,暂时不支持其它的")

    def _pred_data_transfer(self, predict_data, recv_time):
        """Group streaming partial results into per-sentence timelines.

        Example:
            predict_data = [
                {"recoginition_results": {"text": "1", "final_result": False, "para_seq": 0}},
                {"recoginition_results": {"text": "12", "final_result": False, "para_seq": 0}},
                {"recoginition_results": {"text": "123", "final_result": True, "para_seq": 0}},
                {"recoginition_results": {"text": "4", "final_result": False, "para_seq": 0}},
                {"recoginition_results": {"text": "45", "final_result": False, "para_seq": 0}},
                {"recoginition_results": {"text": "456", "final_result": True, "para_seq": 0}},
            ]
            recv_time = [1, 3, 5, 6, 7, 8]
            ->
            [
                [{'text': '1', 'time': 1}, {'text': '12', 'time': 3}, {'text': '123', 'time': 5}],
                [{'text': '4', 'time': 6}, {'text': '45', 'time': 7}, {'text': '456', 'time': 8}],
            ]
        """
        pred_sentence_info = []
        pred_sentence_index = 0
        for predict_item, recv_time_item in zip(predict_data, recv_time):
            # Lazily open a new group when the previous one was finalized.
            if len(pred_sentence_info) == pred_sentence_index:
                pred_sentence_info.append([])
            pred_sentence_info[pred_sentence_index].append(
                {
                    "text": predict_item["recoginition_results"]["text"],
                    "time": recv_time_item,
                }
            )
            # A final result closes the current sentence group.
            if predict_item["recoginition_results"]["final_result"]:
                pred_sentence_index += 1
        return pred_sentence_info
class AsrEnEvaluator(AsrEvaluator):
    """Evaluation pipeline for English ASR results.

    Uses whitespace tokenization for alignment and an external speechio
    script for text normalization.
    """

    def evaluate(self, eval_result) -> Dict:
        """Score one query, accumulating WER-style CER / cut-rate counters.

        Args:
            eval_result: dict with "voice" (list of {"answer": ...} labeled
                segments) and "predict_data" (streaming recognition items).
        """
        self.query_count += 1
        self.voice_count += len(eval_result["voice"])

        # Reference transcripts, one per voice segment (not sentence units).
        label_voice_sentences = [item["answer"] for item in eval_result["voice"]]
        # Split the voice segments into minimal sentence units.
        voice_to_cut_info = self._voice_to_cut_sentence(label_voice_sentences)
        # Normalized reference sentence units.
        label_sentences = self._sentence_list_norm(voice_to_cut_info["cut_sentences"])
        if log_mid_result:
            logger.info(f"label_sentences {label_sentences}")

        # Final (non-partial) predicted sentences, normalized.
        predict_sentences_raw = self._get_predict_final_sentences(eval_result["predict_data"])
        predict_sentences = self._sentence_list_norm(predict_sentences_raw)
        if log_mid_result:
            logger.info(f"predict_sentences {predict_sentences}")

        # Word-level alignment via minimum edit distance.
        label_tokens, predict_tokens = self._sentence_transfer(" ".join(label_sentences), " ".join(predict_sentences))

        # CER (word-level, computed from the aligned token lists)
        cer_info = self.cer(label_tokens, predict_tokens)
        if log_mid_result:
            logger.info(f"cer_info {cer_info}")
        self.one_minus_cer += cer_info["one_minus_cer"]
        self.token_count += cer_info["token_count"]

        # Sentence-cut precision/recall
        cut_info = self.cut_rate(label_sentences, predict_sentences, label_tokens, predict_tokens)
        if log_mid_result:
            logger.info(f"{cut_info['miss_count']}, {cut_info['more_count']}, {cut_info['rate']}")
        self.miss_count += cut_info["miss_count"]
        self.more_count += cut_info["more_count"]
        self.cut_count += cut_info["cut_count"]
        self.rate += cut_info["rate"]

        self.result.append(
            {
                "label_tokens": label_tokens,
                "predict_tokens": predict_tokens,
                "one_minus_cer": cer_info["one_minus_cer"],
                # Fixed: this previously stored cer_info["one_minus_cer"].
                "token_count": cer_info["token_count"],
                "miss_count": cut_info["miss_count"],
                "more_count": cut_info["more_count"],
                "cut_count": cut_info["cut_count"],
                "rate": cut_info["rate"],
            }
        )

    def cer(self, label_tokens, predict_tokens):
        """Compute (1 - error rate) from aligned token lists.

        Counts substitutions (s), deletions (d, gap on the predict side),
        insertions (i, gap on the label side) and correct matches (c).
        """
        s, d, i, c = 0, 0, 0, 0
        for label_token, predict_token in zip(label_tokens, predict_tokens):
            if label_token == predict_token:
                c += 1
            elif predict_token is None:
                d += 1
            elif label_token is None:
                i += 1
            else:
                s += 1
        token_count = s + d + c
        if token_count == 0:
            # Empty reference: perfect score when nothing was inserted,
            # worst score otherwise. (Previously: ZeroDivisionError.)
            return {"one_minus_cer": 1.0 if i == 0 else 0, "token_count": 0}
        cer = (s + d + i) / token_count
        one_minus_cer = max(1.0 - cer, 0)
        return {"one_minus_cer": one_minus_cer, "token_count": token_count}

    def cut_rate(self, label_sentences, predict_sentences, label_tokens, predict_tokens):
        """Sentence-cut quality: missed cuts, spurious cuts, penalty rate.

        A spurious cut is penalized twice as heavily as a missed one; the
        rate is clamped at 0. With an empty reference the rate is 0
        (previously this raised ZeroDivisionError).
        """
        label_final_index_list = set(self._sentence_final_index(label_sentences, label_tokens, "whitespace"))
        pred_final_index_list = set(self._sentence_final_index(predict_sentences, predict_tokens, "whitespace"))
        label_sentence_count = len(label_final_index_list)
        miss_count = len(label_final_index_list - pred_final_index_list)
        more_count = len(pred_final_index_list - label_final_index_list)
        if label_sentence_count:
            rate = max(1 - (miss_count + more_count * 2) / label_sentence_count, 0)
        else:
            rate = 0
        return {
            "miss_count": miss_count,
            "more_count": more_count,
            "cut_count": label_sentence_count,
            "rate": rate,
            "label_final_index_list": label_final_index_list,
            "pred_final_index_list": pred_final_index_list,
        }

    def _sentence_list_norm(self, sentence_list, tokenizer="whitespace"):
        """Normalize English sentences by shelling out to textnorm_en.py.

        Writes ./predict.txt, runs the speechio normalizer over it, then
        reads ./predict_norm.txt back.

        NOTE(review): uses fixed file names in the current working
        directory, so concurrent runs would clobber each other — confirm
        single-process usage. shell=True is required for the inline
        PYTHONPATH assignment; the command is built only from __file__, not
        from untrusted input.
        """
        pwd = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        with open('./predict.txt', 'w', encoding='utf-8') as fp:
            for idx, sentence in enumerate(sentence_list):
                fp.write('%s\t%s\n' % (idx, sentence))
        subprocess.run(
            f'PYTHONPATH={pwd}/utils/speechio python {pwd}/utils/speechio/textnorm_en.py --has_key --to_upper ./predict.txt ./predict_norm.txt',
            shell=True,
            check=True,
        )
        sentence_norm = []
        with open('./predict_norm.txt', 'r', encoding='utf-8') as fp:
            for line in fp:
                line_split_result = line.strip().split('\t', 1)
                if len(line_split_result) >= 2:
                    sentence_norm.append(line_split_result[1])
                # A sentence may normalize to nothing; it is dropped here.
        return sentence_norm

    def _sentence_transfer(self, label_sentence: str, predict_sentence: str, tokenizer="whitespace"):
        """Align label and prediction word-wise by minimum edit distance.

        Args:
            label_sentence: e.g. "HELLO WORLD ARE U OK YEP"
            predict_sentence: e.g. "HELLO WORLD U ARE U OK YEP"
            tokenizer: tokenization mode (only "whitespace" supported)

        Returns:
            (label_tokens, pred_tokens) — equal-length lists where gaps in
            the alignment are None, e.g.:
            label:   ["HELLO", "WORLD", None, "ARE", "U", "OK", "YEP"]
            predict: ["HELLO", "WORLD", "U",  "ARE", "U", "OK", "YEP"]
            Returns None (after logging) for an unsupported tokenizer.
        """
        from utils.speechio import error_rate_zh as error_rate

        if tokenizer == "whitespace":
            alignment, score = error_rate.EditDistance(
                error_rate.tokenize_text(label_sentence, tokenizer),
                error_rate.tokenize_text(predict_sentence, tokenizer),
            )
            label_tokens, pred_tokens = [], []
            for align in alignment:
                label_tokens.append(align.ref)
                pred_tokens.append(align.hyp)
            return (label_tokens, pred_tokens)
        else:
            logger.error("tokenizer 出错了,暂时不支持其它的")

    def post_evaluate(self) -> Dict:
        """No post-processing is needed for English evaluation."""
        pass