"""Benchmark entry script: loads the submit config, starts or locates the ASR
system under test, and (currently disabled) runs the evaluation loop."""
import gc
import json
import os
import sys
import time
import zipfile

import yaml

from schemas.context import ASRContext
from utils.client import Client
from utils.evaluator import BaseEvaluator
from utils.logger import logger
from utils.service import register_sut
|
IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", None) is None
|
|
UNIT_TEST = os.getenv("UNIT_TEST", 0)
|
|
|
|
|
|
def main():
    """Entry point for the benchmark runner.

    Reads every input from environment variables (with local-test defaults),
    loads the submit config, resolves the system-under-test (SUT) websocket
    endpoint, then runs the evaluation loop — which is currently disabled
    (the whole pipeline sits inside a dead triple-quoted string below).
    """
    logger.info("执行……")

    # Input dataset, submit config, and output artifact paths.
    dataset_filepath = os.getenv(
        "DATASET_FILEPATH",
        "./tests/resources/en.zip",
    )
    submit_config_filepath = os.getenv("SUBMIT_CONFIG_FILEPATH", "./tests/resources/submit_config")
    result_filepath = os.getenv("RESULT_FILEPATH", "./out/result")
    bad_cases_filepath = os.getenv("BAD_CASES_FILEPATH", "./out/badcase")
    detail_cases_filepath = os.getenv("DETAILED_CASES_FILEPATH", "./out/detailcase.jsonl")

    # Benchmark/resource name forwarded to the SUT registration helper.
    resource_name = os.getenv("BENCHMARK_NAME")

    # Submit config & launch the system under test.
    # NOTE(review): this branch is gated on DATASET_FILEPATH rather than
    # SUBMIT_CONFIG_FILEPATH — confirm that is intentional.
    if os.getenv("DATASET_FILEPATH", ""):
        from utils.helm import resource_check

        with open(submit_config_filepath, "r") as fp:
            st_config = yaml.safe_load(fp)
        # Normalize/validate the requested deployment resources.
        st_config["values"] = resource_check(st_config.get("values", {}))
        if 'docker_images' in st_config:
            # Multi-image submission: use a fixed endpoint instead of deploying.
            sut_url = "ws://172.26.1.75:9827"
            os.environ['test'] = '1'
        elif 'docker_image' in st_config:
            # Normal path: deploy the submitted image and obtain its endpoint.
            sut_url = register_sut(st_config, resource_name)
        elif UNIT_TEST:
            # Unit-test fallback endpoint.
            sut_url = "ws://172.27.231.36:80"
        else:
            logger.error("config 配置错误,没有 docker_image")
            # Hard exit: bypasses atexit/finally handlers on a bad config.
            os._exit(1)
    else:
        # No dataset configured: mark test mode and use the fixed endpoint.
        os.environ['test'] = '1'
        sut_url = "ws://172.27.231.36:80"
    # NOTE(review): in a unit-test run we stop here, before any evaluation.
    if UNIT_TEST:
        exit(0)

    # The entire evaluation pipeline below is disabled: it is a bare string
    # literal (dead code), kept verbatim for reference.
    # TODO(review): re-enable or delete; note the last logger.info is missing
    # its f-string prefix.
    """
    # 数据集处理
    local_dataset_path = "./dataset"
    os.makedirs(local_dataset_path, exist_ok=True)
    with zipfile.ZipFile(dataset_filepath) as zf:
        zf.extractall(local_dataset_path)
    config_path = os.path.join(local_dataset_path, "data.yaml")
    with open(config_path, "r") as fp:
        dataset_config = yaml.safe_load(fp)

    # 数据集信息
    dataset_global_config = dataset_config.get("global", {})
    dataset_query = dataset_config.get("query_data", {})

    evaluator = BaseEvaluator()

    # 开始预测
    for idx, query_item in enumerate(dataset_query):
        gc.collect()
        logger.info(f"开始执行 {idx} 条数据")

        context = ASRContext(**dataset_global_config)
        context.lang = query_item.get("lang", context.lang)
        context.file_path = os.path.join(local_dataset_path, query_item["file"])
        # context.audio_length = query_item["audio_length"]

        interactions = Client(sut_url, context).action()
        context.append_labels(query_item["voice"])
        context.append_preds(
            interactions["predict_data"],
            interactions["send_time"],
            interactions["recv_time"],
        )
        context.fail = interactions["fail"]
        if IN_TEST:
            with open('output.txt', 'w') as fp:
                original_stdout = sys.stdout
                sys.stdout = fp
                print(context)
                sys.stdout = original_stdout
        evaluator.evaluate(context)
        detail_case = evaluator.gen_detail_case()
        with open(detail_cases_filepath, "a") as fp:
            fp.write(json.dumps(detail_case.to_dict(), ensure_ascii=False) + "\n")
        time.sleep(4)

    evaluator.post_evaluate()
    output_result = evaluator.gen_result()
    # print(evaluator.__dict__)
    logger.info("执行完成. Result = {output_result}")

    with open(result_filepath, "w") as fp:
        json.dump(output_result, fp, indent=2, ensure_ascii=False)
    with open(bad_cases_filepath, "w") as fp:
        fp.write("当前榜单不存在 Bad Case\n")
    """
|
|
|
|
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()
|