1194 lines
39 KiB
Python
1194 lines
39 KiB
Python
import json
|
||
import os
|
||
import sys
|
||
import time
|
||
import tempfile
|
||
import zipfile
|
||
import threading
|
||
from collections import defaultdict
|
||
from typing import Dict, List
|
||
|
||
import yaml
|
||
from pydantic import ValidationError
|
||
|
||
from schemas.dataset import QueryData
|
||
from utils.client_callback import ClientCallback, EvaluateResult, StopException
|
||
from utils.logger import log
|
||
from utils.service import register_sut
|
||
from utils.update_submit import change_product_available
|
||
from utils.file import dump_json, load_yaml, unzip_dir, load_json, write_file, dump_yaml
|
||
from utils.leaderboard import change_product_unavailable
|
||
|
||
|
||
lck = threading.Lock()  # serializes shared evaluation steps across worker threads

# Environment variables provided by the leaderboard runner (required — KeyError if missing).
DATASET_FILEPATH = os.environ["DATASET_FILEPATH"]

RESULT_FILEPATH = os.environ["RESULT_FILEPATH"]

DETAILED_CASES_FILEPATH = os.environ["DETAILED_CASES_FILEPATH"]
SUBMIT_CONFIG_FILEPATH = os.environ["SUBMIT_CONFIG_FILEPATH"]
BENCHMARK_NAME = os.environ["BENCHMARK_NAME"]
# Optional tuning knobs with defaults.
TEST_CONCURRENCY = int(os.getenv('TEST_CONCURRENCY', 1))    # concurrency level (referenced by the legacy test flow)
THRESHOLD_OMCER = float(os.getenv('THRESHOLD_OMCER', 0.8))  # threshold on "one minus CER" — per name; used by legacy ASR flow

log.info(f"DATASET_FILEPATH: {DATASET_FILEPATH}")
workspace_path = "/tmp/workspace"  # dataset archive is unpacked here (see unzip_dir below)


# Environment variables injected by Kubernetes.
MY_POD_IP = os.environ["MY_POD_IP"]

# constants
RESOURCE_NAME = BENCHMARK_NAME  # resource name passed to register_sut()

# Environment variables provided by judge_flow_config.
LANG = os.getenv("lang")        # dataset language override; may be None
SUT_CPU = os.getenv("SUT_CPU", "2")
SUT_MEMORY = os.getenv("SUT_MEMORY", "4Gi")
SUT_VGPU = os.getenv("SUT_VGPU", "1")
#SUT_VGPU_MEM = os.getenv("SUT_VGPU_MEM", str(1843 * int(SUT_VGPU)))
#SUT_VGPU_CORES = os.getenv("SUT_VGPU_CORES", str(8 * int(SUT_VGPU)))
SUT_VGPU_ACCELERATOR = os.getenv("SUT_VGPU_ACCELERATOR", "iluvatar-BI-V100")
RESOURCE_TYPE = os.getenv("RESOURCE_TYPE", "vgpu")
# NOTE(review): assert is stripped under `python -O`; an explicit raise would
# make this config validation robust.
assert RESOURCE_TYPE in [
    "cpu",
    "vgpu",
], "benchmark judge_flow_config error: RESOURCE_TYPE should be cpu or vgpu"


# Unpack the dataset archive into the workspace at import time.
unzip_dir(DATASET_FILEPATH, workspace_path)
|
||
|
||
def get_sut_url_kubernetes():
    """Register the SUT deployment on Kubernetes and return its HTTP base URL.

    Loads the submit config YAML, injects the SUT container spec and the
    iluvatar vGPU resource requests/limits plus a node selector, registers
    the deployment via ``register_sut`` and rewrites the returned websocket
    URL scheme to plain HTTP.

    Returns:
        str: the SUT base URL, e.g. ``http://host:port``.
    """
    with open(SUBMIT_CONFIG_FILEPATH, "r") as f:
        submit_config = yaml.safe_load(f)
        assert isinstance(submit_config, dict)

    submit_config.setdefault("values", {})

    # Container spec of the system under test.
    submit_config["values"]["containers"] = [
        {
            "name": "corex-container",
            "image": "harbor.4pd.io/lab-platform/inf/python:3.9",
            "command": ["sleep"],  # placeholder start command for the model container
            "args": ["3600"],
        }
    ]

    submit_config["values"]["resources"] = {
        "requests": {},
        "limits": {},
    }

    limits = submit_config["values"]["resources"]["limits"]
    requests = submit_config["values"]["resources"]["requests"]

    # iluvatar GPU resource key (vendor-specific replacement for nvidia.com/gpu).
    vgpu_resource = {
        "iluvatar.ai/gpu": SUT_VGPU,
        # Add further keys (e.g. "iluvatar.ai/gpumem") here if the cluster needs them.
    }
    limits.update(vgpu_resource)
    requests.update(vgpu_resource)

    # Pin the pod to nodes carrying the configured accelerator label.
    # Was hard-coded to "iluvatar-BI-V100"; SUT_VGPU_ACCELERATOR has the same
    # default, so behavior is unchanged unless the env var is set.
    submit_config["values"]["nodeSelector"] = {
        "contest.4pd.io/accelerator": SUT_VGPU_ACCELERATOR
    }

    log.info(f"submit_config: {submit_config}")
    log.info(f"RESOURCE_NAME: {RESOURCE_NAME}")

    return register_sut(submit_config, RESOURCE_NAME).replace(
        "ws://", "http://"
    )
|
||
|
||
|
||
def get_sut_url():
    """Return the SUT base URL; Kubernetes is currently the only backend."""
    kubernetes_url = get_sut_url_kubernetes()
    return kubernetes_url
|
||
|
||
#SUT_URL = get_sut_url()
|
||
#os.environ["SUT_URL"] = SUT_URL
|
||
|
||
"""
|
||
def load_dataset(
|
||
dataset_filepath: str,
|
||
) -> Dict[str, List[QueryData]]:
|
||
dataset_path = tempfile.mkdtemp()
|
||
|
||
with zipfile.ZipFile(dataset_filepath) as zf:
|
||
zf.extractall(dataset_path)
|
||
|
||
basename = os.path.basename(dataset_filepath)
|
||
datayaml = os.path.join(dataset_path, "data.yaml")
|
||
if not os.path.exists(datayaml):
|
||
sub_dataset_paths = os.listdir(dataset_path)
|
||
dataset = {}
|
||
for sub_dataset_path in sub_dataset_paths:
|
||
sub_dataset = load_dataset(
|
||
os.path.join(dataset_path, sub_dataset_path)
|
||
)
|
||
for k, v in sub_dataset.items():
|
||
k = os.path.join(basename, k)
|
||
dataset[k] = v
|
||
return dataset
|
||
|
||
with open(datayaml, "r") as f:
|
||
data = yaml.safe_load(f)
|
||
assert isinstance(data, dict)
|
||
|
||
lang = LANG
|
||
data_lang = data.get("global", {}).get("lang")
|
||
if lang is None and data_lang is not None:
|
||
if data_lang is not None:
|
||
# 使用配置中的语言类型
|
||
lang = data_lang
|
||
if lang is None and basename.startswith("asr.") and len(basename) == 4 + 2:
|
||
# 数据集名称为asr.en 可以认为语言为en
|
||
lang = basename[4:]
|
||
if lang is None:
|
||
log.error(
|
||
"数据集错误 通过data.yaml中的 global.lang 或 数据集名称 asr.xx 指定语言类型"
|
||
)
|
||
sys.exit(1)
|
||
|
||
query_data = data.get("query_data", [])
|
||
audio_size_map = {}
|
||
for query in query_data:
|
||
query["lang"] = lang
|
||
query["file"] = os.path.join(dataset_path, query["file"])
|
||
audio_size_map[query["file"]] = os.path.getsize(query["file"])
|
||
# 根据音频大小排序
|
||
query_data = sorted(
|
||
query_data, key=lambda x: audio_size_map[x["file"]], reverse=True
|
||
)
|
||
valid_query_data = []
|
||
for i_query_data in query_data:
|
||
try:
|
||
valid_query_data.append(QueryData.model_validate(i_query_data))
|
||
except ValidationError:
|
||
log.error("数据集错误 数据中query_data格式错误")
|
||
sys.exit(1)
|
||
|
||
return {
|
||
basename: valid_query_data,
|
||
}
|
||
|
||
|
||
def merge_query_data(dataset: Dict[str, List[QueryData]]) -> List[QueryData]:
|
||
query_datas = []
|
||
for query_data in dataset.values():
|
||
query_datas.extend(query_data)
|
||
return query_datas
|
||
|
||
|
||
def run_one_predict(
|
||
client: ClientCallback, query_data: QueryData, task_id: str
|
||
) -> EvaluateResult:
|
||
try:
|
||
client.predict(None, query_data.file, query_data.duration, task_id)
|
||
except StopException:
|
||
sys.exit(1)
|
||
|
||
client.finished.wait()
|
||
|
||
if client.error is not None:
|
||
sys.exit(1)
|
||
|
||
client.app_on = False
|
||
|
||
try:
|
||
with lck:
|
||
ret = client.evaluate(query_data)
|
||
return ret
|
||
except StopException:
|
||
sys.exit(1)
|
||
"""
|
||
|
||
"""
|
||
def predict_task(
|
||
client: ClientCallback, task_id: int, query_data: QueryData, test_results: list
|
||
):
|
||
log.info(f"Task-{task_id}开始评测")
|
||
test_results[task_id] = run_one_predict(client, query_data, str(task_id))
|
||
|
||
|
||
def merge_concurrent_result(evaluate_results: List[EvaluateResult]) -> Dict:
|
||
cer = 0.0
|
||
align_start = {}
|
||
align_end = {}
|
||
first_word_distance_sum = 0.0
|
||
last_word_distance_sum = 0.0
|
||
rtf = 0.0
|
||
first_receive_delay: float = 0.0
|
||
query_count: int = 0
|
||
voice_count: int = 0
|
||
pred_punctuation_num: int = 0
|
||
label_punctuation_num: int = 0
|
||
pred_sentence_punctuation_num: int = 0
|
||
label_setence_punctuation_num: int = 0
|
||
|
||
for evalute_result in evaluate_results:
|
||
cer += evalute_result.cer
|
||
for k, v in evalute_result.align_start.items():
|
||
align_start.setdefault(k, 0)
|
||
align_start[k] += v
|
||
for k, v in evalute_result.align_end.items():
|
||
align_end.setdefault(k, 0)
|
||
align_end[k] += v
|
||
first_word_distance_sum += evalute_result.first_word_distance_sum
|
||
last_word_distance_sum += evalute_result.last_word_distance_sum
|
||
rtf += evalute_result.rtf
|
||
first_receive_delay += evalute_result.first_receive_delay
|
||
query_count += evalute_result.query_count
|
||
voice_count += evalute_result.voice_count
|
||
pred_punctuation_num += evalute_result.pred_punctuation_num
|
||
label_punctuation_num += evalute_result.label_punctuation_num
|
||
pred_sentence_punctuation_num += (
|
||
evalute_result.pred_sentence_punctuation_num
|
||
)
|
||
label_setence_punctuation_num += (
|
||
evalute_result.label_setence_punctuation_num
|
||
)
|
||
lens = len(evaluate_results)
|
||
cer /= lens
|
||
for k, v in align_start.items():
|
||
align_start[k] /= voice_count
|
||
for k, v in align_end.items():
|
||
align_end[k] /= voice_count
|
||
first_word_distance = first_word_distance_sum / voice_count
|
||
last_word_distance = last_word_distance_sum / voice_count
|
||
rtf /= lens
|
||
first_receive_delay /= lens
|
||
json_result = {
|
||
"one_minus_cer": 1 - cer,
|
||
"first_word_distance_mean": first_word_distance,
|
||
"last_word_distance_mean": last_word_distance,
|
||
"query_count": query_count // lens,
|
||
"voice_count": voice_count // lens,
|
||
"rtf": rtf,
|
||
"first_receive_delay": first_receive_delay,
|
||
"punctuation_ratio": (
|
||
pred_punctuation_num / label_punctuation_num
|
||
if label_punctuation_num > 0
|
||
else 1.0
|
||
),
|
||
"sentence_punctuation_ratio": (
|
||
pred_sentence_punctuation_num / label_setence_punctuation_num
|
||
if label_setence_punctuation_num > 0
|
||
else 1.0
|
||
),
|
||
}
|
||
for k, v in align_start.items():
|
||
json_result["start_word_%dms_ratio" % k] = v
|
||
for k, v in align_end.items():
|
||
json_result["end_word_%dms_ratio" % k] = v
|
||
|
||
return json_result
|
||
|
||
|
||
def merge_result(result: Dict[str, List[EvaluateResult]]) -> Dict:
|
||
json_result = {}
|
||
for lang, evaluate_results in result.items():
|
||
if len(evaluate_results) == 0:
|
||
continue
|
||
cer = 0.0
|
||
align_start = {}
|
||
align_end = {}
|
||
first_word_distance_sum = 0.0
|
||
last_word_distance_sum = 0.0
|
||
rtf = 0.0
|
||
first_receive_delay: float = 0.0
|
||
query_count: int = 0
|
||
voice_count: int = 0
|
||
pred_punctuation_num: int = 0
|
||
label_punctuation_num: int = 0
|
||
pred_sentence_punctuation_num: int = 0
|
||
label_setence_punctuation_num: int = 0
|
||
for evalute_result in evaluate_results:
|
||
cer += evalute_result.cer
|
||
for k, v in evalute_result.align_start.items():
|
||
align_start.setdefault(k, 0)
|
||
align_start[k] += v
|
||
for k, v in evalute_result.align_end.items():
|
||
align_end.setdefault(k, 0)
|
||
align_end[k] += v
|
||
first_word_distance_sum += evalute_result.first_word_distance_sum
|
||
last_word_distance_sum += evalute_result.last_word_distance_sum
|
||
rtf += evalute_result.rtf
|
||
first_receive_delay += evalute_result.first_receive_delay
|
||
query_count += evalute_result.query_count
|
||
voice_count += evalute_result.voice_count
|
||
pred_punctuation_num += evalute_result.pred_punctuation_num
|
||
label_punctuation_num += evalute_result.label_punctuation_num
|
||
pred_sentence_punctuation_num += (
|
||
evalute_result.pred_sentence_punctuation_num
|
||
)
|
||
label_setence_punctuation_num += (
|
||
evalute_result.label_setence_punctuation_num
|
||
)
|
||
lens = len(evaluate_results)
|
||
cer /= lens
|
||
for k, v in align_start.items():
|
||
align_start[k] /= voice_count
|
||
for k, v in align_end.items():
|
||
align_end[k] /= voice_count
|
||
first_word_distance = first_word_distance_sum / voice_count
|
||
last_word_distance = last_word_distance_sum / voice_count
|
||
rtf /= lens
|
||
first_receive_delay /= lens
|
||
lang_result = {
|
||
"one_minus_cer": 1 - cer,
|
||
"first_word_distance_mean": first_word_distance,
|
||
"last_word_distance_mean": last_word_distance,
|
||
"query_count": 1,
|
||
"voice_count": voice_count,
|
||
"rtf": rtf,
|
||
"first_receive_delay": first_receive_delay,
|
||
"punctuation_ratio": (
|
||
pred_punctuation_num / label_punctuation_num
|
||
if label_punctuation_num > 0
|
||
else 1.0
|
||
),
|
||
"sentence_punctuation_ratio": (
|
||
pred_sentence_punctuation_num / label_setence_punctuation_num
|
||
if label_setence_punctuation_num > 0
|
||
else 1.0
|
||
),
|
||
}
|
||
for k, v in align_start.items():
|
||
lang_result["start_word_%dms_ratio" % k] = v
|
||
for k, v in align_end.items():
|
||
lang_result["end_word_%dms_ratio" % k] = v
|
||
if lang == "":
|
||
json_result.update(lang_result)
|
||
else:
|
||
json_result[lang] = lang_result
|
||
return json_result
|
||
"""
|
||
|
||
"""
|
||
def main():
|
||
log.info(f'{TEST_CONCURRENCY=}, {THRESHOLD_OMCER=}')
|
||
dataset = load_dataset(DATASET_FILEPATH)
|
||
query_datas = merge_query_data(dataset)
|
||
|
||
#获取 ASR 服务 URL(通常从 Kubernetes 配置)
|
||
sut_url = get_sut_url()
|
||
|
||
#创建多个客户端实例(每个客户端监听不同端口,如 80、81、82...)
|
||
port_base = 80
|
||
clients = [ClientCallback(sut_url, port_base + i) for i in range(TEST_CONCURRENCY)]
|
||
|
||
#准备测试数据与线程
|
||
detail_cases = []
|
||
# we use the same test data for all requests
|
||
query_data = query_datas[0]
|
||
|
||
test_results = [None] * len(clients)
|
||
test_threads = [threading.Thread(target=predict_task, args=(client, task_id, query_data, test_results))
|
||
for task_id, client in enumerate(clients)]
|
||
|
||
#启动并发测试,启动线程并间隔10秒,设置超时时间为1小时
|
||
for t in test_threads:
|
||
t.start()
|
||
time.sleep(10)
|
||
[t.join(timeout=3600) for t in test_threads]
|
||
|
||
#合并结果与评估
|
||
final_result = merge_concurrent_result(test_results)
|
||
product_avaiable = all([c.product_avaiable for c in clients])
|
||
|
||
final_result['concurrent_req'] = TEST_CONCURRENCY
|
||
if final_result['one_minus_cer'] < THRESHOLD_OMCER:
|
||
product_avaiable = False
|
||
|
||
if not product_avaiable:
|
||
final_result['success'] = False
|
||
change_product_available()
|
||
else:
|
||
final_result['success'] = True
|
||
|
||
#保存结果,
|
||
log.info(
|
||
"指标结果为: %s", json.dumps(final_result, indent=2, ensure_ascii=False)
|
||
)
|
||
|
||
time.sleep(120)
|
||
#打印并保存最终结果到文件
|
||
with open(RESULT_FILEPATH, "w") as f:
|
||
json.dump(final_result, f, indent=2, ensure_ascii=False)
|
||
#保存详细测试用例结果
|
||
with open(DETAILED_CASES_FILEPATH, "w") as f:
|
||
json.dump(detail_cases, f, indent=2, ensure_ascii=False)
|
||
"""
|
||
|
||
#############################################################################
|
||
|
||
import requests
|
||
import base64
|
||
|
||
def gen_req_body(apiname, APPId, file_path=None, featureId=None, featureInfo=None, dstFeatureId=None):
    """
    Build the JSON request body for one voiceprint API operation.

    :param apiname: operation name; one of createFeature, createGroup,
        deleteFeature, queryFeatureList, searchFea, searchScoreFea,
        updateFeature, deleteGroup.
    :param APPId: application id placed in the request header.
    :param file_path: audio file path; required by the operations that upload
        audio (createFeature, searchFea, searchScoreFea, updateFeature).
    :param featureId: feature id used by createFeature.
    :param featureInfo: feature description used by createFeature.
    :param dstFeatureId: target feature id used by searchScoreFea.
    :return: dict ready for JSON serialization as the HTTP request body.
    :raises Exception: if ``apiname`` is not a supported operation.
    """
    # Per-operation parameters. "func" and the "<apiname>Res" response-format
    # block follow a common pattern and are added below.
    api_params = {
        'createFeature': {
            "groupId": "test_voiceprint_e",
            "featureId": featureId,
            "featureInfo": featureInfo,
        },
        'createGroup': {
            "groupId": "test_voiceprint_e",
            "groupName": "vip_user",
            "groupInfo": "store_vip_user_voiceprint",
        },
        'deleteFeature': {
            "groupId": "iFLYTEK_examples_groupId",
            "featureId": "iFLYTEK_examples_featureId",
        },
        'queryFeatureList': {
            "groupId": "user_voiceprint_2",
        },
        'searchFea': {
            "groupId": "test_voiceprint_e",
            "topK": 1,
        },
        'searchScoreFea': {
            "groupId": "test_voiceprint_e",
            "dstFeatureId": dstFeatureId,
        },
        'updateFeature': {
            "groupId": "iFLYTEK_examples_groupId",
            "featureId": "iFLYTEK_examples_featureId",
            "featureInfo": "iFLYTEK_examples_featureInfo_update",
        },
        'deleteGroup': {
            "groupId": "iFLYTEK_examples_groupId",
        },
    }
    # Operations that carry a base64-encoded audio payload.
    audio_apis = ('createFeature', 'searchFea', 'searchScoreFea', 'updateFeature')

    if apiname not in api_params:
        # Error message now also lists deleteGroup, which the original
        # message omitted even though the operation is supported.
        raise Exception(
            "输入的apiname不在[createFeature, createGroup, deleteFeature, queryFeatureList, searchFea, searchScoreFea, updateFeature, deleteGroup]内,请检查")

    parameter = {"func": apiname}
    parameter.update(api_params[apiname])
    parameter[apiname + "Res"] = {
        "encoding": "utf8",
        "compress": "raw",
        "format": "json"
    }

    body = {
        "header": {
            "app_id": APPId,
            "status": 3
        },
        "parameter": {
            "s782b4996": parameter
        }
    }

    if apiname in audio_apis:
        body["payload"] = _gen_audio_payload(file_path)

    return body


def _gen_audio_payload(file_path):
    """Read an audio file and wrap it as the base64 ``payload.resource`` block."""
    with open(file_path, "rb") as f:
        audioBytes = f.read()
    return {
        "resource": {
            "encoding": "lame",
            "sample_rate": 16000,
            "channels": 1,
            "bit_depth": 16,
            "status": 3,
            "audio": str(base64.b64encode(audioBytes), 'UTF-8')
        }
    }
|
||
|
||
|
||
|
||
log.info(f"开始请求获取到SUT服务URL")
# Resolve the SUT service URL (registers the deployment on Kubernetes).
sut_url = get_sut_url()
print(f"获取到的SUT_URL: {sut_url}")  # debug output
log.info(f"获取到SUT服务URL: {sut_url}")

from urllib.parse import urlparse

# Module-level global; written (via `global`) by req_url's legacy decode path.
text_decoded = None

################################### added ################################
|
||
def req_url(api_name, APPId, file_path=None, featureId=None, featureInfo=None, dstFeatureId=None):
    """
    Fire one request sequence against the SUT: health check, prediction POST,
    console report, and a dump of the final response JSON to RESULT_FILEPATH.

    :param api_name: voiceprint API operation name (forwarded to gen_req_body)
    :param APPId: application id for the request header
    :param file_path: audio file path forwarded into the request body
    :param featureId: feature id (createFeature)
    :param featureInfo: feature description (createFeature)
    :param dstFeatureId: target feature id (searchScoreFea)
    :return: None — results are printed and written to RESULT_FILEPATH
    """

    global text_decoded

    # NOTE(review): this body is overwritten by the {"limit": 20} dict below;
    # gen_req_body is only kept for its side effects (file read / validation).
    # Confirm this is intentional.
    body = gen_req_body(apiname=api_name, APPId=APPId, file_path=file_path, featureId=featureId, featureInfo=featureInfo, dstFeatureId=dstFeatureId)
    #request_url = 'https://ai-cloud.4paradigm.com:9443/sid/v1/private/s782b4996'

    #request_url = 'https://sut:80/sid/v1/private/s782b4996'

    #headers = {'content-type': "application/json", 'host': 'ai-cloud.4paradigm.com', 'appid': APPId}

    parsed_url = urlparse(sut_url)
    # NOTE(review): these headers are also overwritten a few lines below.
    headers = {'content-type': "application/json", 'host': parsed_url.hostname, 'appid': APPId}

    # 1. Service health check first.
    response = requests.get(f"{sut_url}/health")
    print(response.status_code, response.text)


    # Request headers.
    headers = {"Content-Type": "application/json"}
    # Request body (optionally caps the number of images processed).
    body = {"limit": 20 }

    # Send the POST request.
    response = requests.post(
        f"{sut_url}/v1/private/s782b4996",
        data=json.dumps(body),
        headers=headers
    )

    # Parse the response and report metrics to the console.
    if response.status_code == 200:
        result = response.json()
        print("预测评估结果:")
        print(f"准确率: {result['metrics']['accuracy']}%")
        print(f"平均召回率: {result['metrics']['average_recall']}%")
        print(f"处理图片总数: {result['metrics']['total_images']}")
    else:
        print(f"请求失败,状态码: {response.status_code}")
        print(f"错误信息: {response.text}")


    # Basic-auth credentials.
    # SECURITY(review): hard-coded credentials in source; currently unused
    # (only referenced by the commented-out requests below) — consider
    # removing them or reading from an environment variable.
    auth = ('llm', 'Rmf4#LcG(iFZrjU;2J')
    #response = requests.post(request_url, data=json.dumps(body), headers=headers, auth=auth)

    #response = requests.post(sut_url + "/predict", data=json.dumps(body), headers=headers, auth=auth)
    #response = requests.post(f"{sut_url}/sid/v1/private/s782b4996", data=json.dumps(body), headers=headers, auth=auth)
    """
    response = requests.post(f"{sut_url}/v1/private/s782b4996", data=json.dumps(body), headers=headers)
    """

    #print("HTTP状态码:", response.status_code)
    #print("原始响应内容:", response.text) # 先打印原始内容
    #print(f"请求URL: {sut_url + '/v1/private/s782b4996'}")
    #print(f"请求headers: {headers}")
    #print(f"请求body: {body}")

    #tempResult = json.loads(response.content.decode('utf-8'))
    #print(tempResult)

    """
    # 对text字段进行Base64解码
    if 'payload' in tempResult and 'updateFeatureRes' in tempResult['payload']:
        text_encoded = tempResult['payload']['updateFeatureRes']['text']
        text_decoded = base64.b64decode(text_encoded).decode('utf-8')
        print(f"Base64解码后的text字段内容: {text_decoded}")
    """

    #text_encoded = tempResult['payload']['updateFeatureRes']['text']
    #text_decoded = base64.b64decode(text_encoded).decode('utf-8')
    #print(f"Base64解码后的text字段内容: {text_decoded}")

    # Persist the (last) response JSON to the result file.
    result = response.json()
    with open(RESULT_FILEPATH, "w") as f:
        json.dump(result, f, indent=4, ensure_ascii=False)
    print(f"结果已成功写入 {RESULT_FILEPATH}")
|
||
|
||
# Fallback output paths (env-overridable); result_filepath / bad_cases_filepath
# are consumed by result2file() below.
submit_config_filepath = os.getenv("SUBMIT_CONFIG_FILEPATH", "./tests/resources/submit_config")
result_filepath = os.getenv("RESULT_FILEPATH", "./out/result")
bad_cases_filepath = os.getenv("BAD_CASES_FILEPATH", "./out/badcase")
#detail_cases_filepath = os.getenv("DETAILED_CASES_FILEPATH", "./out/detailcase.jsonl")

from typing import Any, Dict, List
|
||
|
||
def result2file(
    result: Dict[str, Any],
    detail_cases=None,
):
    """Write the final metrics dict to ``result_filepath`` as pretty JSON.

    The previous signature annotated ``detail_cases`` as
    ``List[Dict[str, Any]]`` while defaulting it to ``None`` (implicit
    Optional, disallowed by PEP 484); its effective type is
    ``Optional[List[Dict[str, Any]]]`` and it is currently unused — the
    detailed-case dump was disabled.

    :param result: metrics mapping; skipped when ``None``.
    :param detail_cases: optional per-case records; currently ignored.
    """
    # These paths come from module-level os.getenv(...) calls with non-None
    # defaults, so the asserts are a sanity net rather than real validation.
    assert result_filepath is not None
    assert bad_cases_filepath is not None

    if result is not None:
        with open(result_filepath, "w") as f:
            json.dump(result, f, indent=4, ensure_ascii=False)
|
||
|
||
|
||
def test_image_prediction(sut_url, image_path):
    """POST one image to the SUT and return ``(top_prediction, error)``.

    On success the second element is ``None``; on any failure the first
    element is ``None`` and the second is a human-readable error string.
    """
    endpoint = f"{sut_url}/v1/private/s782b4996"
    try:
        with open(image_path, 'rb') as img:
            reply = requests.post(endpoint, files={'image': img}, timeout=30)
        payload = reply.json()
        if payload.get('status') != 'success':
            return None, f"服务端错误: {payload.get('message')}"
        return payload.get('top_prediction'), None
    except Exception as e:
        return None, f"请求错误: {str(e)}"
|
||
|
||
|
||
|
||
import random
|
||
import time
|
||
#from tqdm import tqdm
|
||
import os
|
||
import requests
|
||
|
||
if __name__ == '__main__':

    print(f"\n===== main开始请求接口 ===============================================")

    # 1. Service health check before running the evaluation loop.
    print(f"\n===== 服务健康检查 ===================================================")
    response = requests.get(f"{sut_url}/health")
    print(response.status_code, response.text)

    ###############################################################################################
    # Evaluation configuration.
    dataset_root = "/tmp/workspace/256ObjectCategoriesNew"  # dataset root directory
    samples_per_class = 3  # images sampled per class
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')  # accepted image formats

    # Kept for compatibility; only the removed legacy loop referenced it.
    correct_predictions = 0

    # Aggregate counters.
    total_samples = 0
    true_positives = 0
    false_positives = 0
    false_negatives = 0
    total_processing_time = 0.0  # total wall-clock time spent on requests (seconds)

    # Walk every class folder (named "<index>.<name>") and sample images.
    for folder_name in os.listdir(dataset_root):
        folder_path = os.path.join(dataset_root, folder_name)

        # Skip stray non-directory entries at the dataset root.
        if not os.path.isdir(folder_path):
            continue

        # Class name is the part after the first dot, lower-cased.
        try:
            class_name = folder_name.split('.', 1)[1].strip().lower()
        except IndexError:
            print(f"警告:文件夹 {folder_name} 命名格式不正确,跳过该文件夹")
            continue

        # Collect all image files in this class folder.
        image_files = []
        for file in os.listdir(folder_path):
            file_path = os.path.join(folder_path, file)
            if os.path.isfile(file_path) and file.lower().endswith(image_extensions):
                image_files.append(file_path)

        # Randomly sample up to samples_per_class images (all if fewer).
        selected_images = random.sample(
            image_files,
            min(samples_per_class, len(image_files))
        )

        # Send each selected image to the SUT and score the prediction.
        for img_path in selected_images:
            total_samples += 1
            start_time = time.time()
            prediction, error = test_image_prediction(sut_url, img_path)

            # Per-image latency includes network round-trip and inference.
            processing_time = time.time() - start_time
            total_processing_time += processing_time

            if error:
                print(f"处理图片 {img_path} 失败: {error}")
                # A failed request counts as a missed positive.
                false_negatives += 1
                continue

            pred_class = prediction.get('class_name', '').lower()
            confidence = prediction.get('confidence', 0)

            # Correct when the true class name appears in the predicted class
            # (case-insensitive substring match).
            is_correct = class_name in pred_class

            if is_correct:
                true_positives += 1
            else:
                false_positives += 1
                false_negatives += 1

            print(f"图片: {os.path.basename(img_path)} | 真实: {class_name} | 预测: {pred_class} | 置信度: {confidence:.4f} | {'正确' if is_correct else '错误'}")

    # Result skeleton; metric fields are filled in below.
    result = {
        "total_processing_time": round(total_processing_time, 6),
        "throughput": 0.0,
        "accuracy": 0.0,
        "recall": 0.0
    }

    if total_samples == 0:
        print("没有找到任何图片样本")

    # accuracy = correct predictions / total samples (percent)
    accuracy = true_positives / total_samples * 100 if total_samples > 0 else 0

    # recall = TP / (TP + FN) (percent)
    recall_denominator = true_positives + false_negatives
    recall = true_positives / recall_denominator * 100 if recall_denominator > 0 else 0

    # throughput = images per second; guard against near-zero elapsed time.
    throughput = total_samples / total_processing_time if total_processing_time > 1e-6 else 0

    result.update({
        "throughput": round(throughput, 6),
        "accuracy": round(accuracy, 6),
        "recall": round(recall, 6)
    })

    # Final console summary.
    print("\n" + "="*50)
    print(f"总样本数: {total_samples}")
    print(f"总处理时间: {total_processing_time:.4f}秒")
    print(f"处理速度: {throughput:.2f}张/秒")
    print(f"正确预测: {true_positives}")
    print(f"错误预测: {total_samples - true_positives}")
    print(f"准确率: {accuracy:.4f} ({true_positives}/{total_samples})")
    print(f"召回率: {recall:.4f} ({true_positives}/{recall_denominator})")
    print("="*50)

    # Persist the metrics via the shared helper.
    result2file(result)

    exit_code = 0
|
||
|
||
|