update

This commit adds a new file, utils/metrics_plus.py (50 lines).
|
||||
from typing import List
|
||||
|
||||
from utils.tokenizer import TokenizerType
|
||||
|
||||
|
||||
def replace_general_punc(
    sentences: List[str], tokenizer: TokenizerType
) -> List[str]:
    """Strip common full-width and half-width punctuation from sentences.

    Replaces the original function utils.metrics.cut_sentence. Each
    punctuation character is replaced with a space when the tokenizer is
    whitespace-based (so token boundaries survive) and removed outright
    otherwise; every sentence is then stripped and lower-cased.
    """
    general_puncs = [
        "······",
        "......",
        "。",
        ",",
        "?",
        "!",
        ";",
        ":",
        "...",
        ".",
        ",",
        "?",
        "!",
        ";",
        ":",
    ]
    # Whitespace tokenizers need a separator where the punctuation was;
    # other tokenizers just drop the character.
    replacer = " " if tokenizer == TokenizerType.whitespace else ""
    # Per-character translation table: multi-char entries above contribute
    # their individual characters.
    table = str.maketrans({ch: replacer for ch in "".join(general_puncs)})
    return [s.translate(table).strip().lower() for s in sentences]
|
||||
|
||||
|
||||
def distance_point_line(
    point: float, line_start: float, line_end: float
) -> float:
    """Return the 1-D distance from ``point`` to the segment [line_start, line_end].

    Despite the name, this is distance to a closed *segment*, not an
    infinite line: points inside the interval are at distance 0.0, points
    outside measure their gap to the nearest endpoint. Assumes
    ``line_start <= line_end`` (callers pass an ordered interval —
    presumably; verify against callers).

    Fixes: the original returned an int ``0`` despite the ``-> float``
    annotation, and wrapped differences in a redundant ``abs()`` (the
    guarding comparisons already fix the sign).
    """
    if line_start <= point <= line_end:
        return 0.0
    if point < line_start:
        # point < line_start, so this difference is positive — no abs needed.
        return line_start - point
    return point - line_end
|
||||
Reference in New Issue
Block a user