758 lines
27 KiB
Python
758 lines
27 KiB
Python
import atexit
|
||
import concurrent.futures
|
||
import fcntl
|
||
import gc
|
||
import glob
|
||
import json
|
||
import os
|
||
import random
|
||
import signal
|
||
import sys
|
||
import tempfile
|
||
import threading
|
||
import time
|
||
import zipfile
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
|
||
import yaml
|
||
from fabric import Connection
|
||
from vmplatform import VMOS, Client, VMDataDisk
|
||
|
||
from schemas.context import ASRContext
|
||
from utils.client_async import ClientAsync
|
||
from utils.evaluator import BaseEvaluator
|
||
from utils.logger import logger
|
||
from utils.service import register_sut
|
||
|
||
# True when running outside the submission pipeline (no submit config given).
IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", None) is None
# NOTE(review): when the env var is set, getenv returns a *string*, so even
# UNIT_TEST=0 is truthy; only the unset default (int 0) is falsy — confirm intended.
UNIT_TEST = os.getenv("UNIT_TEST", 0)

# Number of dataset jobs in this submit; main() waits for this many successes.
DATASET_NUM = os.getenv("DATASET_NUM")

# VM leaderboard parameters
SUT_TYPE = os.getenv("SUT_TYPE", "kubernetes")
SHARE_SUT = os.getenv("SHARE_SUT", "true") == "true"
# Globals published by deploy_*_sut() and read by clean_vm_atexit().
VM_ID = 0
VM_IP = ""
# True when THIS process is responsible for deploying (and deleting) the SUT.
do_deploy_chart = True
VM_CPU = int(os.getenv("VM_CPU", "2"))
VM_MEM = int(os.getenv("VM_MEM", "4096"))
MODEL_BASEPATH = os.getenv("MODEL_BASEPATH", "/tmp/customer/leaderboard/pc_asr")
# Maps submitted model name -> base runtime archive filename under MODEL_BASEPATH.
MODEL_MAPPING = json.loads(os.getenv("MODEL_MAPPING", "{}"))
SSH_KEY_DIR = os.getenv("SSH_KEY_DIR", "/workspace")
SSH_PUBLIC_KEY_FILE = os.path.join(SSH_KEY_DIR, "ssh-key-ecdsa.pub")
SSH_KEY_FILE = os.path.join(SSH_KEY_DIR, "ssh-key-ecdsa")

CONNECT_KWARGS = {"key_filename": SSH_KEY_FILE}

# Shared-SUT coordination parameters (cross-job marker files under /tmp).
JOB_ID = os.getenv("JOB_ID")
dirname = "/tmp/submit_private/sut_share"
os.makedirs(dirname, exist_ok=True)
SUT_SHARE_LOCK = os.path.join(dirname, "lock.lock")  # "I am the deployer" marker (created with mode "x")
SUT_SHARE_USE_LOCK = os.path.join(dirname, "use.lock")  # flock target serializing SUT use
SUT_SHARE_STATUS = os.path.join(dirname, "status.json")  # deployer -> waiter status broadcast
SUT_SHARE_JOB_STATUS = os.path.join(dirname, f"job_status.{JOB_ID}")  # per-job progress file
SUT_SHARE_PUBLIC_FAIL = os.path.join(dirname, "one_job_failed")  # any-job-failed flag
# Held open for the process lifetime; flock'ed in get_sut_url_vm(), released in main().
fd_lock = open(SUT_SHARE_USE_LOCK, "a")
|
||
|
||
|
||
def clean_vm_atexit():
    """atexit hook: delete the VM this process created, if any.

    Best-effort cleanup — a failed delete is only logged, never raised,
    because this runs during interpreter shutdown.
    """
    global VM_ID, do_deploy_chart
    # Nothing to do unless this process both created a VM and owns deployment.
    if not (VM_ID and do_deploy_chart):
        return
    logger.info("删除vm")
    err_msg = Client().delete_vm(VM_ID)
    if err_msg:
        logger.warning(f"删除vm失败: {err_msg}")
|
||
|
||
|
||
def put_file_to_vm(c: Connection, local_path: str, remote_path: str):
    """Upload a local file to the VM over the given Fabric connection.

    Args:
        c: an open Fabric connection to the VM.
        local_path: path of the file on this host.
        remote_path: destination path (or directory) on the VM.
    """
    # Lazy %-style args so the logging module formats only if the record is
    # emitted — consistent with the other logger calls in this module.
    logger.info("uploading file %s to %s", local_path, remote_path)
    result = c.put(local_path, remote_path)
    logger.info("uploaded %s to %s", result.local, result.remote)
|
||
|
||
|
||
def deploy_windows_sut():
    """Create a Windows 10 VM, provision the ASR runtime and submitted model,
    start the SUT server in a background thread, and return its ws:// URL.

    Reads the submit config from SUBMIT_CONFIG_FILEPATH, resolves the model
    archive via the configured nfs entries, and publishes VM_ID/VM_IP through
    module globals so clean_vm_atexit() can tear the VM down.

    Raises:
        AssertionError: if the submit config is missing required keys or a
            remote provisioning command fails.
        RuntimeError: if no nfs entry matches the submitted model_key.
    """
    global VM_ID
    global VM_IP

    # --- parse & validate the submit configuration -------------------------
    submit_config_filepath = os.getenv("SUBMIT_CONFIG_FILEPATH", "")
    with open(submit_config_filepath, "r") as fp:
        st_config = yaml.safe_load(fp)
    assert "model" in st_config, "未配置model"
    assert "model_key" in st_config, "未配置model_key"
    assert "config.json" in st_config, "未配置config.json"
    nfs = st_config.get("leaderboard_options", {}).get("nfs", [])
    assert len(nfs) > 0, "未配置nfs"
    assert st_config["model"] in MODEL_MAPPING, "提交模型不在可用模型范围内"

    model = st_config["model"]
    model_key = st_config["model_key"]
    model_path = ""
    config = st_config["config.json"]
    # Resolve the model archive's local path from the nfs entry whose name
    # matches model_key; the mount root depends on the storage backend.
    exist = False
    for nfs_item in nfs:
        if nfs_item["name"] == model_key:
            exist = True
            if nfs_item["source"] == "ceph_customer":
                model_path = os.path.join(
                    "/tmp/customer",
                    nfs_item["srcRelativePath"],
                )
            else:
                model_path = os.path.join(
                    "/tmp/juicefs",
                    nfs_item["srcRelativePath"],
                )
            break
    if not exist:
        raise RuntimeError(f"未找到nfs配置项 name={model_key}")

    # Write the SUT's config.json locally; model_path inside it points at
    # where the archive will be unpacked on the VM (E:\model\<archive stem>).
    config_path = os.path.join(tempfile.mkdtemp(), "config.json")
    model_dir = os.path.basename(model_path).split(".")[0]
    config["model_path"] = f"E:\\model\\{model_dir}"
    with open(config_path, "w") as fp:
        json.dump(config, fp, ensure_ascii=False, indent=4)

    # --- create the VM -----------------------------------------------------
    vmclient = Client()
    with open(SSH_PUBLIC_KEY_FILE, "r") as fp:
        sshpublickey = fp.read().rstrip()
    VM_ID = vmclient.create_vm(
        "amd64",
        VMOS.windows10,
        VM_CPU,
        VM_MEM,
        "leaderboard-%s-submit-%s-job-%s"
        % (
            os.getenv("BENCHMARK_NAME"),
            os.getenv("SUBMIT_ID"),
            os.getenv("JOB_ID"),
        ),
        sshpublickey,
        datadisks=[
            VMDataDisk(
                size=50,
                disk_type="ssd",
                mount_path="/",
                filesystem="NTFS",
            )
        ],
    )
    # Make sure the VM is deleted on normal exit and on SIGTERM.
    atexit.register(clean_vm_atexit)
    signal.signal(signal.SIGTERM, lambda signum, _: sys.exit(signum))
    VM_IP = vmclient.wait_until_vm_running(VM_ID)
    logger.info("vm created successfully, vm_ip: %s", VM_IP)

    def sut_startup():
        # Runs in a daemon thread: launches the SUT server and blocks for its
        # lifetime (warn=True so a non-zero exit status does not raise here).
        with Connection(
            VM_IP,
            "administrator",
            connect_kwargs=CONNECT_KWARGS,
        ) as c:
            # Fix: the original assigned a dead faster-whisper path here that
            # was immediately overwritten by the sensevoice path.
            script_path = "E:\\install\\asr\\sensevoice\\server"
            bat_filepath = f"{script_path}\\start.bat"
            config_filepath = "E:\\submit\\config.json"
            # Empty command acts as a connectivity smoke check before launch.
            result = c.run("")
            assert result.ok
            c.run(
                f'cd /d {script_path} & set "EDGE_ML_ENV_HOME=E:\\install" & {bat_filepath} {config_filepath}',
                warn=True,
            )

    # --- provision the VM over SSH ----------------------------------------
    with Connection(
        VM_IP,
        "administrator",
        connect_kwargs=CONNECT_KWARGS,
    ) as c:
        # Upload the base runtime archive for the selected model family.
        model_filepath = os.path.join(MODEL_BASEPATH, MODEL_MAPPING[model])
        filename = os.path.basename(model_filepath)
        put_file_to_vm(c, model_filepath, "/E:/")

        result = c.run("mkdir E:\\base")
        assert result.ok
        result = c.run("mkdir E:\\model")
        assert result.ok
        result = c.run("mkdir E:\\submit")
        assert result.ok

        # Fix: unpack the archive uploaded above — the previous command used a
        # broken literal path instead of the computed (and unused) filename.
        result = c.run(
            f"tar zxvf E:\\{filename} -C E:\\base --strip-components 1"
        )
        assert result.ok

        result = c.run("E:\\base\\setup-win.bat E:\\install")
        assert result.ok

        # Upload the generated config and the model archive, then unpack it.
        put_file_to_vm(c, config_path, "/E:/submit")
        put_file_to_vm(c, model_path, "/E:/model")
        result = c.run(
            f"tar zxvf E:\\model\\{os.path.basename(model_path)} -C E:\\model"
        )
        assert result.ok
    threading.Thread(target=sut_startup, daemon=True).start()
    # Give the server a head start before callers try to connect.
    time.sleep(60)

    return f"ws://{VM_IP}:{config['port']}"
|
||
|
||
|
||
def deploy_macos_sut():
    """Create a macOS 12 VM, provision the ASR runtime and submitted model,
    start the SUT server in a background thread, and return its ws:// URL.

    Mirrors deploy_windows_sut(), except the data-disk mount point must be
    discovered at runtime (APFS mounts under /Volumes/data*), so the SUT's
    config.json can only be written after the VM is up.

    Raises:
        AssertionError: if the submit config is missing required keys or a
            remote provisioning command fails.
        RuntimeError: if no nfs entry matches the submitted model_key.
    """
    global VM_ID
    global VM_IP

    # --- parse & validate the submit configuration -------------------------
    submit_config_filepath = os.getenv("SUBMIT_CONFIG_FILEPATH", "")
    with open(submit_config_filepath, "r") as fp:
        st_config = yaml.safe_load(fp)
    assert "model" in st_config, "未配置model"
    assert "model_key" in st_config, "未配置model_key"
    assert "config.json" in st_config, "未配置config.json"
    nfs = st_config.get("leaderboard_options", {}).get("nfs", [])
    assert len(nfs) > 0, "未配置nfs"
    assert st_config["model"] in MODEL_MAPPING, "提交模型不在可用模型范围内"

    model = st_config["model"]
    model_key = st_config["model_key"]
    model_path = ""
    config = st_config["config.json"]
    # Resolve the model archive's local path from the nfs entry whose name
    # matches model_key; the mount root depends on the storage backend.
    exist = False
    for nfs_item in nfs:
        if nfs_item["name"] == model_key:
            exist = True
            if nfs_item["source"] == "ceph_customer":
                model_path = os.path.join(
                    "/tmp/customer",
                    nfs_item["srcRelativePath"],
                )
            else:
                model_path = os.path.join(
                    "/tmp/juicefs",
                    nfs_item["srcRelativePath"],
                )
            break
    if not exist:
        raise RuntimeError(f"未找到nfs配置项 name={model_key}")
    config_path = os.path.join(tempfile.mkdtemp(), "config.json")
    model_dir = os.path.basename(model_path).split(".")[0]

    # --- create the VM -----------------------------------------------------
    vmclient = Client()
    with open(SSH_PUBLIC_KEY_FILE, "r") as fp:
        sshpublickey = fp.read().rstrip()
    VM_ID = vmclient.create_vm(
        "amd64",
        VMOS.macos12,
        VM_CPU,
        VM_MEM,
        "leaderboard-%s-submit-%s-job-%s"
        % (
            os.getenv("BENCHMARK_NAME"),
            os.getenv("SUBMIT_ID"),
            os.getenv("JOB_ID"),
        ),
        sshpublickey,
        datadisks=[
            VMDataDisk(
                size=50,
                disk_type="ssd",
                mount_path="/",
                filesystem="apfs",
            )
        ],
    )
    # Make sure the VM is deleted on normal exit and on SIGTERM.
    atexit.register(clean_vm_atexit)
    signal.signal(signal.SIGTERM, lambda signum, _: sys.exit(signum))
    VM_IP = vmclient.wait_until_vm_running(VM_ID)
    logger.info("vm created successfully, vm_ip: %s", VM_IP)

    with Connection(
        VM_IP,
        "admin",
        connect_kwargs=CONNECT_KWARGS,
    ) as c:
        # Discover the actual mount point of the data disk.
        result = c.run("ls -d /Volumes/data*")
        assert result.ok
        volume_path = result.stdout.strip()

        # Now that the volume path is known, write the SUT's config.json;
        # model_path points at where the archive will be unpacked on the VM.
        config["model_path"] = f"{volume_path}/model/{model_dir}"
        with open(config_path, "w") as fp:
            json.dump(config, fp, ensure_ascii=False, indent=4)

    def sut_startup():
        # Runs in a daemon thread: launches the SUT server and blocks for its
        # lifetime (warn=True so a non-zero exit status does not raise here).
        with Connection(
            VM_IP,
            "admin",
            connect_kwargs=CONNECT_KWARGS,
        ) as c:
            script_path = f"{volume_path}/install/asr/sensevoice/server"
            startsh = f"{script_path}/start.sh"
            config_filepath = f"{volume_path}/submit/config.json"
            c.run(
                f"cd {script_path} && sh {startsh} {config_filepath}",
                warn=True,
            )

    # --- provision the VM over SSH ----------------------------------------
    with Connection(
        VM_IP,
        "admin",
        connect_kwargs=CONNECT_KWARGS,
    ) as c:
        # Upload the base runtime archive for the selected model family.
        model_filepath = os.path.join(MODEL_BASEPATH, MODEL_MAPPING[model])
        filename = os.path.basename(model_filepath)
        put_file_to_vm(c, model_filepath, f"{volume_path}")

        result = c.run(f"mkdir {volume_path}/base")
        assert result.ok
        result = c.run(f"mkdir {volume_path}/model")
        assert result.ok
        result = c.run(f"mkdir {volume_path}/submit")
        assert result.ok

        # Fix: unpack the archive uploaded above — the previous command used a
        # broken literal path instead of the computed (and unused) filename.
        result = c.run(
            f"tar zxvf {volume_path}/{filename} -C {volume_path}/base --strip-components 1"  # noqa: E501
        )
        assert result.ok

        result = c.run(
            f"sh {volume_path}/base/setup-mac.sh {volume_path}/install x64"
        )
        assert result.ok

        # Upload the generated config and the model archive, then unpack it.
        put_file_to_vm(c, config_path, f"{volume_path}/submit")
        put_file_to_vm(c, model_path, f"{volume_path}/model")
        result = c.run(
            f"tar zxvf {volume_path}/model/{os.path.basename(model_path)} -C {volume_path}/model"  # noqa: E501
        )
        assert result.ok
    threading.Thread(target=sut_startup, daemon=True).start()
    # Give the server a head start before callers try to connect.
    time.sleep(60)

    return f"ws://{VM_IP}:{config['port']}"
|
||
|
||
|
||
def get_sut_url_vm(vm_type: str):
    """Deploy (or attach to) the VM-based SUT and return its ws:// URL.

    In shared-SUT mode at most one job per node deploys the VM; the others
    wait for the deployer to publish status, then take turns using the SUT
    serialized by an flock on fd_lock. The deployer is elected by exclusive
    creation ("x" mode) of SUT_SHARE_LOCK.

    Args:
        vm_type: "windows" or "macos"; anything else deploys the macOS SUT.

    Returns:
        The SUT websocket URL ("" only if the shared wait loop times out
        without ever seeing a "running" status — callers get a broken URL
        in that case; NOTE(review): consider raising instead — confirm).
    """
    global VM_ID
    global VM_IP
    global do_deploy_chart

    do_deploy_chart = True
    # Bring up the SUT.

    def check_job_failed():
        # Background watchdog: if any sibling job flagged a failure, abort
        # this job too (daemon thread, polls every 30s).
        while True:
            time.sleep(30)
            if os.path.exists(SUT_SHARE_PUBLIC_FAIL):
                logger.error("there is a job failed in current submit")
                sys.exit(1)

    sut_url = ""
    threading.Thread(target=check_job_failed, daemon=True).start()
    if SHARE_SUT:
        # Random jitter so concurrent jobs don't race on the lock file at
        # exactly the same instant.
        time.sleep(10 * random.random())
        try:
            # Exclusive create: succeeds for exactly one job — the deployer.
            open(SUT_SHARE_LOCK, "x").close()
        except Exception:
            # Lock file already exists: another job is (or was) the deployer.
            do_deploy_chart = False

        start_at = time.time()

        def file_last_updated_at(file: str):
            # mtime of the status file, or our start time while it doesn't
            # exist yet (so the 24h timeout counts from our own start).
            return os.stat(file).st_mtime if os.path.exists(file) else start_at

        if not do_deploy_chart:
            with open(SUT_SHARE_JOB_STATUS, "w") as f:
                f.write("waiting")
            # Wait up to 24h of status-file staleness for the deployer to
            # finish; the deployer refreshes the file every 30s while pending.
            while (
                time.time() - file_last_updated_at(SUT_SHARE_STATUS)
                <= 60 * 60 * 24
            ):
                logger.info(
                    "Waiting sut application to be deployed by another job"
                )
                time.sleep(10 + random.random())
                if os.path.exists(SUT_SHARE_STATUS):
                    # The deployer may be mid-write; retry the read a few
                    # times before giving up.
                    get_status = False
                    for _ in range(10):
                        try:
                            with open(SUT_SHARE_STATUS, "r") as f:
                                status = json.load(f)
                            get_status = True
                            break
                        except Exception:
                            time.sleep(1 + random.random())
                            continue
                    if not get_status:
                        raise RuntimeError(
                            "Failed to get status of sut application"
                        )
                    assert (
                        status.get("status") != "failed"
                    ), "Failed to deploy sut application, \
please check other job logs"
                    if status.get("status") == "running":
                        # Adopt the deployer's VM identity and SSH keys so
                        # this job can reach the same SUT.
                        VM_ID = status.get("vmid")
                        VM_IP = status.get("vmip")
                        sut_url = status.get("sut_url")
                        with open(SSH_PUBLIC_KEY_FILE, "w") as fp:
                            fp.write(status.get("pubkey"))
                        with open(SSH_KEY_FILE, "w") as fp:
                            fp.write(status.get("prikey"))
                        logger.info("Successfully get deployed sut application")
                        break

    if do_deploy_chart:
        # This job deploys the SUT (always true when SHARE_SUT is off).
        try:
            # Hold the use-lock for the whole deployment; waiters queue on it.
            fcntl.flock(fd_lock, fcntl.LOCK_EX)
            with open(SUT_SHARE_JOB_STATUS, "w") as f:
                f.write("waiting")
            pending = True

            def update_status():
                # Heartbeat: keep SUT_SHARE_STATUS fresh ("pending") every
                # 30s so waiters' staleness timeout doesn't fire mid-deploy.
                while pending:
                    time.sleep(30)
                    if not pending:
                        break
                    with open(SUT_SHARE_STATUS, "w") as f:
                        json.dump({"status": "pending"}, f)

            threading.Thread(target=update_status, daemon=True).start()
            if vm_type == "windows":
                sut_url = deploy_windows_sut()
            else:
                sut_url = deploy_macos_sut()
        except Exception:
            # Broadcast failure to sibling jobs, then re-raise.
            open(SUT_SHARE_PUBLIC_FAIL, "w").close()
            with open(SUT_SHARE_STATUS, "w") as f:
                json.dump({"status": "failed"}, f)
            raise
        finally:
            # Stop the heartbeat thread.
            pending = False
        # Publish the running SUT (VM identity, URL, and SSH key pair) so
        # waiting jobs can attach to it.
        with open(SUT_SHARE_STATUS, "w") as f:
            pubkey = ""
            with open(SSH_PUBLIC_KEY_FILE, "r") as fp:
                pubkey = fp.read().rstrip()
            prikey = ""
            with open(SSH_KEY_FILE, "r") as fp:
                prikey = fp.read()
            json.dump(
                {
                    "status": "running",
                    "vmid": VM_ID,
                    "vmip": VM_IP,
                    "pubkey": pubkey,
                    "sut_url": sut_url,
                    "prikey": prikey,
                },
                f,
            )
    else:
        # Non-deployer: poll until the current user releases the use-lock,
        # then grab it non-blocking — jobs use the shared SUT one at a time.
        while True:
            time.sleep(5 + random.random())
            try:
                fcntl.flock(fd_lock, fcntl.LOCK_EX | fcntl.LOCK_NB)
                break
            except Exception:
                logger.info("尝试抢占调用sut失败,继续等待 5s ...")

    with open(SUT_SHARE_JOB_STATUS, "w") as f:
        f.write("running")

    return sut_url
|
||
|
||
|
||
def get_sut_url():
    """Return the websocket URL of the system under test.

    For windows/macos leaderboards this delegates to get_sut_url_vm().
    Otherwise (kubernetes) it loads the submit config, overwrites its
    resource limits/requests with the leaderboard quota, and registers the
    SUT chart. Without DATASET_FILEPATH it falls back to a fixed debug
    endpoint.
    """
    if SUT_TYPE in ("windows", "macos"):
        return get_sut_url_vm(SUT_TYPE)

    submit_config_filepath = os.getenv(
        "SUBMIT_CONFIG_FILEPATH", "./tests/resources/submit_config"
    )
    CPU = os.getenv("SUT_CPU", "2")
    MEMORY = os.getenv("SUT_MEMORY", "4Gi")
    resource_name = os.getenv("BENCHMARK_NAME")

    # DATASET_FILEPATH set => real run; otherwise use the local test endpoint.
    if os.getenv("DATASET_FILEPATH", ""):
        with open(submit_config_filepath, "r") as fp:
            st_config = yaml.safe_load(fp)
        # Overwrite any user-supplied resources with the leaderboard quota.
        if "values" not in st_config:
            st_config["values"] = {}
        st_config["values"]["resources"] = {}
        st_config["values"]["resources"]["limits"] = {}
        st_config["values"]["resources"]["limits"]["cpu"] = CPU
        st_config["values"]["resources"]["limits"]["memory"] = MEMORY
        st_config["values"]["resources"]["requests"] = {}
        st_config["values"]["resources"]["requests"]["cpu"] = CPU
        st_config["values"]["resources"]["requests"]["memory"] = MEMORY
        if os.getenv("RESOURCE_TYPE", "cpu") == "cpu":
            # CPU-only leaderboard: reject any GPU resource request.
            # NOTE(review): limits/requests were just reset above, so this
            # check can never trigger as written — confirm whether it should
            # inspect the *submitted* config instead.
            values = st_config["values"]
            limits = values.get("resources", {}).get("limits", {})
            requests = values.get("resources", {}).get("requests", {})
            if (
                "nvidia.com/gpu" in limits
                or "nvidia.com/gpumem" in limits
                or "nvidia.com/gpucores" in limits
                or "nvidia.com/gpu" in requests
                or "nvidia.com/gpumem" in requests
                or "nvidia.com/gpucores" in requests
            ):
                raise Exception("禁止使用GPU!")
        else:
            # vGPU leaderboard: grant SUT_VGPU vGPUs with fixed per-vGPU
            # memory (1843) and core (8) quotas, pinned to A10 vGPU nodes.
            vgpu_num = int(os.getenv("SUT_VGPU", "3"))
            limits = st_config["values"]["resources"]["limits"]
            requests = st_config["values"]["resources"]["requests"]
            limits["nvidia.com/gpu"] = str(vgpu_num)
            limits["nvidia.com/gpumem"] = str(1843 * vgpu_num)
            limits["nvidia.com/gpucores"] = str(8 * vgpu_num)
            requests["nvidia.com/gpu"] = str(vgpu_num)
            requests["nvidia.com/gpumem"] = str(1843 * vgpu_num)
            requests["nvidia.com/gpucores"] = str(8 * vgpu_num)
            st_config["values"]["nodeSelector"] = {}
            st_config["values"]["nodeSelector"][
                "contest.4pd.io/accelerator"
            ] = "A10vgpu"
            st_config["values"]["tolerations"] = []
            toleration_item = {}
            toleration_item["key"] = "hosttype"
            toleration_item["operator"] = "Equal"
            toleration_item["value"] = "vgpu"
            toleration_item["effect"] = "NoSchedule"
            st_config["values"]["tolerations"].append(toleration_item)
        if "docker_images" in st_config:
            # Debug hook ("docker_images", plural): skip chart deployment
            # and use a fixed endpoint.
            sut_url = "ws://172.26.1.75:9827"
            os.environ["test"] = "1"
        elif "docker_image" in st_config:
            sut_url = register_sut(st_config, resource_name)
        elif UNIT_TEST:
            sut_url = "ws://172.27.231.36:80"
        else:
            logger.error("config 配置错误,没有 docker_image")
            # NOTE(review): os._exit skips atexit handlers (VM cleanup);
            # sys.exit may be intended here — confirm.
            os._exit(1)
        return sut_url
    else:
        # Local/test mode: no dataset, point at the fixed debug endpoint.
        # (A dead intermediate assignment to ws://172.27.231.36:80 that was
        # immediately overwritten has been removed.)
        os.environ["test"] = "1"
        sut_url = "ws://172.26.1.75:9827"
        return sut_url
|
||
|
||
|
||
def load_merge_dataset(dataset_filepath: str) -> dict:
    """Unpack a merged multi-language ASR dataset and build a combined query list.

    The archive at dataset_filepath contains one inner "asr.<lang>" zip per
    language. Each language's queries are resolved to absolute audio paths,
    sorted by file size (largest first), truncated to roughly half an hour of
    audio, and finally interleaved into config["query_data"] with a fixed-seed
    shuffle so the order is deterministic.

    Args:
        dataset_filepath: path to the outer dataset zip archive.

    Returns:
        Mapping of language -> per-language config, plus a "query_data" key
        holding the shuffled merged query items (each tagged with "lang").
    """
    dataset_root = "./dataset"
    os.makedirs(dataset_root, exist_ok=True)
    with zipfile.ZipFile(dataset_filepath) as archive:
        archive.extractall(dataset_root)

    config = {}
    for entry in os.listdir(dataset_root):
        # Only the per-language inner archives are named "asr.<lang>".
        if not entry.startswith("asr."):
            continue
        lang = entry[len("asr."):]
        lang_dir = os.path.join(dataset_root, lang)
        os.makedirs(lang_dir, exist_ok=True)
        with zipfile.ZipFile(os.path.join(dataset_root, entry)) as archive:
            archive.extractall(lang_dir)
        with open(os.path.join(lang_dir, "data.yaml"), "r") as fp:
            lang_config = yaml.safe_load(fp)

        # Rewrite each query's file to an absolute path and record its size.
        sizes = {}
        for item in lang_config.get("query_data", []):
            resolved = os.path.join(lang_dir, item["file"])
            item["file"] = resolved
            sizes[resolved] = os.path.getsize(resolved)
        lang_config["query_data"] = sorted(
            lang_config.get("query_data", []),
            key=lambda item: sizes[item["file"]],
            reverse=True,
        )

        # Cap each language at about half an hour of audio (size/32000 is the
        # per-file duration estimate in seconds).
        kept = 0
        total_seconds = 0.0
        for item in lang_config["query_data"]:
            total_seconds += sizes[item["file"]] / 32000
            kept += 1
            if total_seconds >= 30 * 60:
                break

        lang_config["query_data"] = lang_config["query_data"][:kept]
        config[lang] = lang_config

    # Merge every language's queries into one deterministic shuffled list.
    config["query_data"] = []
    for lang, lang_config in config.items():
        if lang == "query_data":
            continue
        config["query_data"].extend(
            {**item, "lang": lang} for item in lang_config["query_data"]
        )
    random.Random(0).shuffle(config["query_data"])

    return config
|
||
|
||
|
||
def postprocess_failed():
    """Flag the shared-SUT failure marker so sibling jobs abort themselves.

    Creating (or truncating) the marker file is the whole signal; its
    content is never read.
    """
    with open(SUT_SHARE_PUBLIC_FAIL, "w"):
        pass
|
||
|
||
|
||
def main():
    """Run the ASR benchmark end to end.

    Unpacks the dataset, brings up (or attaches to) the SUT, fans the audio
    queries out over a thread pool, evaluates each result, and writes the
    result / bad-case / detail-case output files. In shared-SUT mode the
    deployer job additionally waits for all sibling jobs to finish.
    """
    dataset_filepath = os.getenv(
        "DATASET_FILEPATH",
        "/Users/4paradigm/Projects/dataset/asr/de.zip",
        # "./tests/resources/en.zip",
    )
    result_filepath = os.getenv("RESULT_FILEPATH", "./out/result")
    bad_cases_filepath = os.getenv("BAD_CASES_FILEPATH", "./out/badcase")
    detail_cases_filepath = os.getenv(
        "DETAILED_CASES_FILEPATH", "./out/detailcase.jsonl"
    )
    thread_num = int(os.getenv("THREAD_NUM", "1"))

    # --- dataset handling --------------------------------------------------
    config = {}
    # NOTE(review): truthiness of the raw string — MERGE_DATASET=0 still
    # enables merge mode; only MERGE_DATASET="" disables it. Confirm intended.
    if os.getenv("MERGE_DATASET", "1"):
        config = load_merge_dataset(dataset_filepath)
        dataset_query = config["query_data"]
    else:
        # Single-language dataset: one zip with data.yaml at its root.
        local_dataset_path = "./dataset"
        os.makedirs(local_dataset_path, exist_ok=True)
        with zipfile.ZipFile(dataset_filepath) as zf:
            zf.extractall(local_dataset_path)
        config_path = os.path.join(local_dataset_path, "data.yaml")
        with open(config_path, "r") as fp:
            dataset_config = yaml.safe_load(fp)
        # Tag every query with the language and sort by audio size
        # (descending) so the longest files are dispatched first.
        lang = os.getenv("lang")
        if lang is None:
            lang = dataset_config.get("global", {}).get("lang", "en")
        audio_lengths = []
        for query_item in dataset_config.get("query_data", []):
            query_item["lang"] = lang
            audio_path = os.path.join(local_dataset_path, query_item["file"])
            query_item["file"] = audio_path
            audio_lengths.append(os.path.getsize(audio_path) / 1024 / 1024)
        # NOTE(review): .index() makes this sort O(n^2) and assumes unique
        # items; it works but is fragile — consider keying by file path.
        dataset_config["query_data"] = sorted(
            dataset_config.get("query_data", []),
            key=lambda x: audio_lengths[dataset_config["query_data"].index(x)],
            reverse=True,
        )
        # Dataset info
        # dataset_global_config = dataset_config.get("global", {})
        dataset_query = dataset_config.get("query_data", {})
        config[lang] = dataset_config

    # --- sut url -----------------------------------------------------------
    sut_url = get_sut_url()

    try:
        # --- run the benchmark ---------------------------------------------
        logger.info("开始执行")
        evaluator = BaseEvaluator()
        future_list = []
        with ThreadPoolExecutor(max_workers=thread_num) as executor:
            # Submit one ASR request per query item.
            for idx, query_item in enumerate(dataset_query):
                context = ASRContext(
                    **config[query_item["lang"]].get("global", {}),
                )
                context.lang = query_item["lang"]
                context.file_path = query_item["file"]
                context.append_labels(query_item["voice"])
                future = executor.submit(
                    ClientAsync(sut_url, context, idx).action
                )
                future_list.append(future)
            # Evaluate results as they complete and append detail cases.
            for future in concurrent.futures.as_completed(future_list):
                context = future.result()
                evaluator.evaluate(context)
                detail_case = evaluator.gen_detail_case()
                with open(detail_cases_filepath, "a") as fp:
                    fp.write(
                        json.dumps(
                            detail_case.to_dict(),
                            ensure_ascii=False,
                        )
                        + "\n",
                    )
                # Audio contexts can be large; drop them eagerly.
                del context
                gc.collect()

        evaluator.post_evaluate()
        output_result = evaluator.gen_result()
        logger.info("执行完成")

        with open(result_filepath, "w") as fp:
            json.dump(output_result, fp, indent=2, ensure_ascii=False)
        with open(bad_cases_filepath, "w") as fp:
            fp.write("当前榜单不存在 Bad Case\n")

        if SHARE_SUT:
            # Record our success and hand the SUT use-lock to the next job.
            with open(SUT_SHARE_JOB_STATUS, "w") as f:
                f.write("success")

            fcntl.flock(fd_lock, fcntl.LOCK_UN)
        fd_lock.close()
        # The deployer job keeps the VM alive until every sibling job has
        # reported success (one job_status.* file per dataset job).
        while SHARE_SUT and do_deploy_chart:
            time.sleep(30)
            success_num = 0
            for job_status_file in glob.glob(dirname + "/job_status.*"):
                with open(job_status_file, "r") as f:
                    job_status = f.read()
                success_num += job_status == "success"
            if success_num == int(DATASET_NUM):
                break
            logger.info("Waiting for all jobs to complete")
    except Exception:
        # Propagate the failure to sibling jobs before re-raising.
        if SHARE_SUT:
            postprocess_failed()
        raise
    sys.exit(0)
|
||
|
||
|
||
# Script entry point.
if __name__ == "__main__":
    main()
|