update
This commit is contained in:
53
scripts/check_dataset_time.py
Normal file
53
scripts/check_dataset_time.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import os
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
def main(dataset_dir):
    """Scan every dataset sub-folder and report voice segments with bad timing.

    Each immediate sub-directory of ``dataset_dir`` is expected to contain a
    ``data.yaml`` whose ``query_data`` list holds queries, each with a
    ``voice`` list of mappings carrying ``answer``, ``start`` and ``end``.
    Segments are checked in ``start`` order and two problems are printed:

    * ``err1`` - a segment whose ``start`` is greater than its ``end``.
    * ``err2`` - a segment whose ``start`` is smaller than the ``end`` of the
      previous segment.

    A ``-----`` marker line is printed for every query whose segments are not
    already stored in ``start`` order.  Finally the number of scanned
    folders, the set of folders with any problem, and the per-folder count of
    ``err2`` findings are printed.

    Args:
        dataset_dir: Path of the folder holding one sub-folder per dataset.

    Raises:
        FileNotFoundError: If a sub-folder has no ``data.yaml``.
    """
    dataset_names = [
        name
        for name in os.listdir(dataset_dir)
        if os.path.isdir(os.path.join(dataset_dir, name))
    ]

    problem_dirs = set()
    problem_count = defaultdict(int)
    for name in dataset_names:
        yaml_path = os.path.join(dataset_dir, name, "data.yaml")
        # Read as UTF-8 explicitly so non-ASCII answers do not depend on the
        # platform's default locale encoding.
        with open(yaml_path, "r", encoding="utf-8") as f:
            # NOTE(review): the file looks like plain data; yaml.safe_load
            # would be the safer loader if these folders can be untrusted.
            data = yaml.full_load(f)

        for query_i, query in enumerate(data["query_data"]):
            # An empty or null ``voice`` entry has nothing to check; the
            # previous code crashed on ``voices[0]`` in that case.
            raw_voices = query["voice"] or []
            voices = sorted(raw_voices, key=lambda v: v["start"])
            if voices != raw_voices:
                print("-----", name)

            for voice_i, voice in enumerate(voices):
                if voice["start"] > voice["end"]:
                    print(
                        "err1: %s 第%s个query的第%d个voice的start大于end: %s"
                        % (name, query_i, voice_i, voice["answer"])
                    )
                    problem_dirs.add(name)
                # The first segment has no predecessor to overlap with.
                if voice_i > 0 and voice["start"] < voices[voice_i - 1]["end"]:
                    print(
                        "err2: %s 第%s个query的第%d个voice的start小于前一个voice的end: %s"
                        % (name, query_i, voice_i, voice["answer"])
                    )
                    problem_dirs.add(name)
                    problem_count[name] += 1

    print(len(dataset_names))
    print(problem_dirs)
    print(problem_count)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: the first CLI argument is the dataset root folder.
    if len(sys.argv) >= 2:
        main(sys.argv[1])
    else:
        print("指定 测试数据集文件夹")
        sys.exit(1)
|
||||
108
scripts/convert_callback_dataset.py
Normal file
108
scripts/convert_callback_dataset.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import zipfile
|
||||
|
||||
import yaml
|
||||
|
||||
"""
|
||||
target (layout of the generated data.yaml, shown as JSON)
{
    "global": {
        "lang": ""
    },
    "query_data": [
        {
            "file": "",
            "duration": 2.0,
            "voice": [
                {
                    "answer": "",
                    "start": 0.0,
                    "end": 1.0
                }
            ]
        }
    ]
}
|
||||
"""
|
||||
|
||||
|
||||
def situation_a(meta, dataset_folder, output_folder):
    """Convert a "combined" callback dataset into per-language datasets.

    Expected ``meta`` layout::

        {
            "combined": {
                "en": [
                    {
                        "wav": "*.wav",
                        "transcriptions": [
                            {"text": "", "start": 0.0, "end": 1.0}
                        ],
                        "duration": 2.0
                    }
                ]
            }
        }

    For every language key under ``meta["combined"]`` this function:

    1. creates ``<output_folder>/<lang>/``;
    2. copies the ``.mp3`` counterpart of each listed ``wav`` path from
       ``dataset_folder`` into that folder, keeping the relative sub-path;
    3. writes ``<output_folder>/<lang>/data.yaml`` with ``global.lang`` and
       one ``query_data`` entry (``file``, ``duration``, ``voice``) per item;
    4. packs the language folder into ``<output_folder>/<lang>.zip``.

    Args:
        meta: Parsed metadata mapping with a top-level ``"combined"`` key.
        dataset_folder: Folder that holds the source ``.mp3`` files.
        output_folder: Folder that receives the per-language folders and zips.

    Raises:
        AssertionError: If a language key is not exactly two characters long.
        FileNotFoundError: If an expected ``.mp3`` file is missing.
    """
    meta = meta["combined"]

    for lang, arr in meta.items():
        print("processing", lang)
        assert len(lang) == 2, "unexpected language key: %r" % (lang,)
        lang_folder = os.path.join(output_folder, lang)
        os.makedirs(lang_folder, exist_ok=True)

        query_data = []
        data = {"global": {"lang": lang}, "query_data": query_data}
        for item in arr:
            os.makedirs(
                os.path.join(lang_folder, os.path.dirname(item["wav"])),
                exist_ok=True,
            )
            # Swap the extension via splitext instead of slicing off four
            # characters, so extensions of other lengths are handled too.
            mp3_file = os.path.splitext(item["wav"])[0] + ".mp3"
            shutil.copyfile(
                os.path.join(dataset_folder, mp3_file),
                os.path.join(lang_folder, mp3_file),
            )
            query_data.append(
                {
                    "file": mp3_file,
                    "duration": float(item["duration"]),
                    "voice": [
                        {
                            "answer": v["text"],
                            "start": float(v["start"]),
                            "end": float(v["end"]),
                        }
                        for v in item["transcriptions"]
                    ],
                }
            )

        # allow_unicode=True emits raw non-ASCII characters, so the text file
        # itself must be opened as UTF-8 rather than the locale default.
        yaml_path = os.path.join(lang_folder, "data.yaml")
        with open(yaml_path, "w", encoding="utf-8") as f:
            yaml.dump(data, f, indent=2, allow_unicode=True, encoding="utf-8")

        with zipfile.ZipFile(
            os.path.join(output_folder, lang + ".zip"), "w"
        ) as ziper:
            for path, _, files in os.walk(lang_folder):
                for file in files:
                    full_path = os.path.join(path, file)
                    # Archive names are relative to the language folder.
                    ziper.write(
                        full_path,
                        os.path.relpath(full_path, lang_folder),
                        zipfile.ZIP_DEFLATED,
                    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: argv[1] is the source dataset folder (it must hold
    # meta.json), argv[2] is the folder that receives the converted output.
    if len(sys.argv) < 3:
        print("指定 数据集文件夹路径 输出路径")
        sys.exit(1)
    dataset_folder = sys.argv[1]
    output_folder = sys.argv[2]

    # JSON text is UTF-8 by convention; do not rely on the locale default,
    # which breaks on non-ASCII transcriptions on some platforms.
    with open(os.path.join(dataset_folder, "meta.json"), encoding="utf-8") as f:
        meta = json.load(f)
    situation_a(meta, dataset_folder, output_folder)
|
||||
56
scripts/debug_detailcase.py
Normal file
56
scripts/debug_detailcase.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import json
|
||||
import sys
|
||||
|
||||
from schemas.dataset import QueryData
|
||||
from schemas.stream import StreamDataModel
|
||||
from utils.evaluator_plus import evaluate_editops
|
||||
|
||||
|
||||
def main(detailcase_file: str):
    """Re-run the edit-ops evaluation for the first case of a detailcase file.

    The file is parsed as JSON and only its first element is used.  Its
    ``preds`` entries are turned into ``StreamDataModel`` objects, those with
    a falsy ``final_result`` are dropped, and the ``label`` entry is turned
    into a ``QueryData``.  The value returned by ``evaluate_editops`` is
    printed.

    Args:
        detailcase_file: Path of the detailcase JSON file.
    """
    with open(detailcase_file) as case_fp:
        first_case = json.load(case_fp)[0]

    all_preds = [StreamDataModel(**raw) for raw in first_case["preds"]]
    final_preds = [pred for pred in all_preds if pred.final_result]
    label = QueryData(**first_case["label"])

    print(evaluate_editops(label, final_preds))
|
||||
|
||||
|
||||
def evaluate_from_record(detailcase_file: str, record_path: str):
    """Re-run the edit-ops evaluation using tokens stored in a record file.

    The label comes from the first element of the detailcase JSON file.  The
    record JSON supplies ``tokens_pred``, ``tokens_label`` and
    ``recognition_results``.  Only recognition results with a truthy
    ``final_result`` are kept, together with the ``tokens_pred`` entry at the
    same index.  The value returned by ``evaluate_editops`` is printed.

    Args:
        detailcase_file: Path of the detailcase JSON file (source of label).
        record_path: Path of the record JSON file.
    """
    with open(detailcase_file) as case_fp:
        first_case = json.load(case_fp)[0]
    label = QueryData(**first_case["label"])

    with open(record_path) as record_fp:
        record = json.load(record_fp)
    all_tokens_pred = record["tokens_pred"]
    tokens_label = record["tokens_label"]
    all_results = [
        StreamDataModel(**raw) for raw in record["recognition_results"]
    ]

    # Pair by index so each kept result keeps its own predicted tokens.
    final_pairs = [
        (all_tokens_pred[idx], result)
        for idx, result in enumerate(all_results)
        if result.final_result
    ]
    tokens_pred = [tokens for tokens, _ in final_pairs]
    recognition_results = [result for _, result in final_pairs]

    print(
        evaluate_editops(label, recognition_results, tokens_pred, tokens_label)
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: the first CLI argument is the detailcase file.
    if len(sys.argv) >= 2:
        main(sys.argv[1])
        # Alternative entry point, enable manually when a record file is
        # available:
        # evaluate_from_record(sys.argv[1], sys.argv[2])
    else:
        print("请指定 detailcase 文件路径")
        sys.exit(1)
|
||||
Reference in New Issue
Block a user