Files
enginex-bi_series-vc-cnn/run_callback_cuda.py
zhousha 55a67e817e update
2025-08-06 15:38:55 +08:00

1194 lines
39 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import sys
import time
import tempfile
import zipfile
import threading
from collections import defaultdict
from typing import Dict, List
import yaml
from pydantic import ValidationError
from schemas.dataset import QueryData
from utils.client_callback import ClientCallback, EvaluateResult, StopException
from utils.logger import log
from utils.service import register_sut
from utils.update_submit import change_product_available
from utils.file import dump_json, load_yaml, unzip_dir, load_json, write_file, dump_yaml
from utils.leaderboard import change_product_unavailable
lck = threading.Lock()  # serializes evaluate() calls across worker threads
# --- Environment variables injected by the leaderboard runner ---
DATASET_FILEPATH = os.environ["DATASET_FILEPATH"]
RESULT_FILEPATH = os.environ["RESULT_FILEPATH"]
DETAILED_CASES_FILEPATH = os.environ["DETAILED_CASES_FILEPATH"]
SUBMIT_CONFIG_FILEPATH = os.environ["SUBMIT_CONFIG_FILEPATH"]
BENCHMARK_NAME = os.environ["BENCHMARK_NAME"]
# Number of concurrent test clients (default 1).
TEST_CONCURRENCY = int(os.getenv('TEST_CONCURRENCY', 1))
# Minimum acceptable one-minus-CER score (default 0.8).
THRESHOLD_OMCER = float(os.getenv('THRESHOLD_OMCER', 0.8))
log.info(f"DATASET_FILEPATH: {DATASET_FILEPATH}")
workspace_path = "/tmp/workspace"  # dataset archive is extracted here
# --- Environment variables injected by Kubernetes ---
MY_POD_IP = os.environ["MY_POD_IP"]
# constants
RESOURCE_NAME = BENCHMARK_NAME
# --- Environment variables from judge_flow_config ---
LANG = os.getenv("lang")
SUT_CPU = os.getenv("SUT_CPU", "2")
SUT_MEMORY = os.getenv("SUT_MEMORY", "4Gi")
SUT_VGPU = os.getenv("SUT_VGPU", "1")
#SUT_VGPU_MEM = os.getenv("SUT_VGPU_MEM", str(1843 * int(SUT_VGPU)))
#SUT_VGPU_CORES = os.getenv("SUT_VGPU_CORES", str(8 * int(SUT_VGPU)))
SUT_VGPU_ACCELERATOR = os.getenv("SUT_VGPU_ACCELERATOR", "iluvatar-BI-V100")
RESOURCE_TYPE = os.getenv("RESOURCE_TYPE", "vgpu")
# NOTE: `assert` is stripped under `python -O`; config validation relies on it here.
assert RESOURCE_TYPE in [
    "cpu",
    "vgpu",
], "benchmark judge_flow_config error: RESOURCE_TYPE should be cpu or vgpu"
# Extract the dataset archive into the workspace at import time (module side effect).
unzip_dir(DATASET_FILEPATH, workspace_path)
def get_sut_url_kubernetes():
    """Register the SUT deployment with Kubernetes and return its HTTP URL.

    Loads the submit config, overrides its container spec and resource
    requests/limits (the live path always requests an Iluvatar GPU), then
    registers the SUT and rewrites the returned websocket URL to plain HTTP.
    Several disabled alternative configurations are retained below as bare
    string literals (they are never executed).

    :return: the SUT service base URL with "ws://" replaced by "http://"
    """
    with open(SUBMIT_CONFIG_FILEPATH, "r") as f:
        submit_config = yaml.safe_load(f)
    assert isinstance(submit_config, dict)
    submit_config.setdefault("values", {})
    submit_config["values"]["containers"] = [
        {
            "name": "corex-container",
            "image": "harbor.4pd.io/lab-platform/inf/python:3.9",  # container image
            "command": ["sleep"],  # replace with the model start command (python interpreter)
            "args": ["3600"],  # replace with model args / inference script invocation
            # volume mounts (disabled)
            #"volumeMounts": [
            #    {
            #        "name": "model-volume",
            #        "mountPath": "/model"  # mount at /model
            #    }
            #]
        }
    ]
    """
    # 添加存储卷配置
    submit_config["values"]["volumes"] = [
        {
            "name": "model-volume",
            "persistentVolumeClaim": {
                "claimName": "sid-model-pvc" # 使用已有的PVC
            }
        }
    ]
    """
    """
    # Inject specified cpu and memory
    resource = {
        "cpu": SUT_CPU,
        "memory": SUT_MEMORY,
    }
    """
    submit_config["values"]["resources"] = {
        "requests":{},
        "limits": {},
    }
    limits = submit_config["values"]["resources"]["limits"]
    # NOTE(review): this local name shadows the `requests` HTTP module
    # (imported later at module level); harmless inside this function but
    # easy to misread.
    requests = submit_config["values"]["resources"]["requests"]
    """
    # ########## 关键修改替换为iluvatar GPU配置 ##########
    if RESOURCE_TYPE == "vgpu": # 假设你的模型需要GPU
        # 替换nvidia资源键为iluvatar.ai/gpu
        vgpu_resource = {
            "iluvatar.ai/gpu": SUT_VGPU, # 对应你的GPU资源键
            # 若需要其他资源如显存按你的K8s配置补充例如
            # "iluvatar.ai/gpumem": SUT_VGPU_MEM,
        }
        limits.update(vgpu_resource)
        requests.update(vgpu_resource)
        # 节点选择器替换为你的accelerator标签
        submit_config["values"]["nodeSelector"] = {
            "contest.4pd.io/accelerator": "iluvatar-BI-V100" # 你的节点标签
        }
        # 容忍度替换为你的tolerations配置
        submit_config["values"]["tolerations"] = [
            {
                "key": "hosttype",
                "operator": "Equal",
                "value": "iluvatar",
                "effect": "NoSchedule",
            }
        ]
        # #########################################
        # 禁止CPU模式下使用GPU资源保持原逻辑
    else:
        if "iluvatar.ai/gpu" in limits or "iluvatar.ai/gpu" in requests:
            log.error("禁止在CPU模式下使用GPU资源")
            sys.exit(1)
        #gpukeys = ["iluvatar.ai/gpu"] # 检查iluvatar GPU键
        #for key in gpukeys:
        #    if key in limits or key in requests:
        #        log.error("禁止使用vgpu资源")
        #        sys.exit(1)
    """
    # NOTE(review): unlike the disabled RESOURCE_TYPE branch above, the live
    # code below ALWAYS requests an Iluvatar GPU, even when RESOURCE_TYPE is
    # "cpu" — confirm this is intentional.
    vgpu_resource = {
        "iluvatar.ai/gpu": SUT_VGPU,  # Iluvatar GPU resource key
        # add other resources (e.g. GPU memory) per your K8s config, e.g.:
        # "iluvatar.ai/gpumem": SUT_VGPU_MEM,
    }
    limits.update(vgpu_resource)
    requests.update(vgpu_resource)
    # Pin the pod to accelerator-labelled nodes.
    submit_config["values"]["nodeSelector"] = {
        "contest.4pd.io/accelerator": "iluvatar-BI-V100"  # node label
    }
    # NOTE(review): the tolerations assignment below sits inside a bare string
    # literal and is therefore DEAD code — the pod is scheduled with no
    # tolerations. Confirm whether it should be re-enabled.
    """
    submit_config["values"]["tolerations"] = [
        {
            "key": "hosttype",
            "operator": "Equal",
            "value": "iluvatar",
            "effect": "NoSchedule",
        },
        {
            "key": "hosttype",
            "operator": "Equal",
            "value": "arm64",
            "effect": "NoSchedule",
        },
        {
            "key": "hosttype",
            "operator": "Equal",
            "value": "myinit",
            "effect": "NoSchedule",
        },
        {
            "key": "hosttype",
            "operator": "Equal",
            "value": "middleware",
            "effect": "NoSchedule",
        }
    ]
    """
    """
    {
        "key": "node-role.kubernetes.io/master",
        "operator": "Exists",
        "effect": "NoSchedule",
    },
    {
        "key": "node.kubernetes.io/not-ready",
        "operator": "Exists",
        "effect": "NoExecute",
        "tolerationSeconds": 300
    },
    {
        "key": "node.kubernetes.io/unreachable",
        "operator": "Exists",
        "effect": "NoExecute",
        "tolerationSeconds": 300
    }
    """
    log.info(f"submit_config: {submit_config}")
    log.info(f"RESOURCE_NAME: {RESOURCE_NAME}")
    # register_sut returns a ws:// URL; callers expect plain HTTP.
    return register_sut(submit_config, RESOURCE_NAME).replace(
        "ws://", "http://"
    )
def get_sut_url():
    """Resolve the SUT base URL; Kubernetes is the only supported backend."""
    sut_endpoint = get_sut_url_kubernetes()
    return sut_endpoint
#SUT_URL = get_sut_url()
#os.environ["SUT_URL"] = SUT_URL
"""
def load_dataset(
dataset_filepath: str,
) -> Dict[str, List[QueryData]]:
dataset_path = tempfile.mkdtemp()
with zipfile.ZipFile(dataset_filepath) as zf:
zf.extractall(dataset_path)
basename = os.path.basename(dataset_filepath)
datayaml = os.path.join(dataset_path, "data.yaml")
if not os.path.exists(datayaml):
sub_dataset_paths = os.listdir(dataset_path)
dataset = {}
for sub_dataset_path in sub_dataset_paths:
sub_dataset = load_dataset(
os.path.join(dataset_path, sub_dataset_path)
)
for k, v in sub_dataset.items():
k = os.path.join(basename, k)
dataset[k] = v
return dataset
with open(datayaml, "r") as f:
data = yaml.safe_load(f)
assert isinstance(data, dict)
lang = LANG
data_lang = data.get("global", {}).get("lang")
if lang is None and data_lang is not None:
if data_lang is not None:
# 使用配置中的语言类型
lang = data_lang
if lang is None and basename.startswith("asr.") and len(basename) == 4 + 2:
# 数据集名称为asr.en 可以认为语言为en
lang = basename[4:]
if lang is None:
log.error(
"数据集错误 通过data.yaml中的 global.lang 或 数据集名称 asr.xx 指定语言类型"
)
sys.exit(1)
query_data = data.get("query_data", [])
audio_size_map = {}
for query in query_data:
query["lang"] = lang
query["file"] = os.path.join(dataset_path, query["file"])
audio_size_map[query["file"]] = os.path.getsize(query["file"])
# 根据音频大小排序
query_data = sorted(
query_data, key=lambda x: audio_size_map[x["file"]], reverse=True
)
valid_query_data = []
for i_query_data in query_data:
try:
valid_query_data.append(QueryData.model_validate(i_query_data))
except ValidationError:
log.error("数据集错误 数据中query_data格式错误")
sys.exit(1)
return {
basename: valid_query_data,
}
def merge_query_data(dataset: Dict[str, List[QueryData]]) -> List[QueryData]:
query_datas = []
for query_data in dataset.values():
query_datas.extend(query_data)
return query_datas
def run_one_predict(
client: ClientCallback, query_data: QueryData, task_id: str
) -> EvaluateResult:
try:
client.predict(None, query_data.file, query_data.duration, task_id)
except StopException:
sys.exit(1)
client.finished.wait()
if client.error is not None:
sys.exit(1)
client.app_on = False
try:
with lck:
ret = client.evaluate(query_data)
return ret
except StopException:
sys.exit(1)
"""
"""
def predict_task(
client: ClientCallback, task_id: int, query_data: QueryData, test_results: list
):
log.info(f"Task-{task_id}开始评测")
test_results[task_id] = run_one_predict(client, query_data, str(task_id))
def merge_concurrent_result(evaluate_results: List[EvaluateResult]) -> Dict:
cer = 0.0
align_start = {}
align_end = {}
first_word_distance_sum = 0.0
last_word_distance_sum = 0.0
rtf = 0.0
first_receive_delay: float = 0.0
query_count: int = 0
voice_count: int = 0
pred_punctuation_num: int = 0
label_punctuation_num: int = 0
pred_sentence_punctuation_num: int = 0
label_setence_punctuation_num: int = 0
for evalute_result in evaluate_results:
cer += evalute_result.cer
for k, v in evalute_result.align_start.items():
align_start.setdefault(k, 0)
align_start[k] += v
for k, v in evalute_result.align_end.items():
align_end.setdefault(k, 0)
align_end[k] += v
first_word_distance_sum += evalute_result.first_word_distance_sum
last_word_distance_sum += evalute_result.last_word_distance_sum
rtf += evalute_result.rtf
first_receive_delay += evalute_result.first_receive_delay
query_count += evalute_result.query_count
voice_count += evalute_result.voice_count
pred_punctuation_num += evalute_result.pred_punctuation_num
label_punctuation_num += evalute_result.label_punctuation_num
pred_sentence_punctuation_num += (
evalute_result.pred_sentence_punctuation_num
)
label_setence_punctuation_num += (
evalute_result.label_setence_punctuation_num
)
lens = len(evaluate_results)
cer /= lens
for k, v in align_start.items():
align_start[k] /= voice_count
for k, v in align_end.items():
align_end[k] /= voice_count
first_word_distance = first_word_distance_sum / voice_count
last_word_distance = last_word_distance_sum / voice_count
rtf /= lens
first_receive_delay /= lens
json_result = {
"one_minus_cer": 1 - cer,
"first_word_distance_mean": first_word_distance,
"last_word_distance_mean": last_word_distance,
"query_count": query_count // lens,
"voice_count": voice_count // lens,
"rtf": rtf,
"first_receive_delay": first_receive_delay,
"punctuation_ratio": (
pred_punctuation_num / label_punctuation_num
if label_punctuation_num > 0
else 1.0
),
"sentence_punctuation_ratio": (
pred_sentence_punctuation_num / label_setence_punctuation_num
if label_setence_punctuation_num > 0
else 1.0
),
}
for k, v in align_start.items():
json_result["start_word_%dms_ratio" % k] = v
for k, v in align_end.items():
json_result["end_word_%dms_ratio" % k] = v
return json_result
def merge_result(result: Dict[str, List[EvaluateResult]]) -> Dict:
json_result = {}
for lang, evaluate_results in result.items():
if len(evaluate_results) == 0:
continue
cer = 0.0
align_start = {}
align_end = {}
first_word_distance_sum = 0.0
last_word_distance_sum = 0.0
rtf = 0.0
first_receive_delay: float = 0.0
query_count: int = 0
voice_count: int = 0
pred_punctuation_num: int = 0
label_punctuation_num: int = 0
pred_sentence_punctuation_num: int = 0
label_setence_punctuation_num: int = 0
for evalute_result in evaluate_results:
cer += evalute_result.cer
for k, v in evalute_result.align_start.items():
align_start.setdefault(k, 0)
align_start[k] += v
for k, v in evalute_result.align_end.items():
align_end.setdefault(k, 0)
align_end[k] += v
first_word_distance_sum += evalute_result.first_word_distance_sum
last_word_distance_sum += evalute_result.last_word_distance_sum
rtf += evalute_result.rtf
first_receive_delay += evalute_result.first_receive_delay
query_count += evalute_result.query_count
voice_count += evalute_result.voice_count
pred_punctuation_num += evalute_result.pred_punctuation_num
label_punctuation_num += evalute_result.label_punctuation_num
pred_sentence_punctuation_num += (
evalute_result.pred_sentence_punctuation_num
)
label_setence_punctuation_num += (
evalute_result.label_setence_punctuation_num
)
lens = len(evaluate_results)
cer /= lens
for k, v in align_start.items():
align_start[k] /= voice_count
for k, v in align_end.items():
align_end[k] /= voice_count
first_word_distance = first_word_distance_sum / voice_count
last_word_distance = last_word_distance_sum / voice_count
rtf /= lens
first_receive_delay /= lens
lang_result = {
"one_minus_cer": 1 - cer,
"first_word_distance_mean": first_word_distance,
"last_word_distance_mean": last_word_distance,
"query_count": 1,
"voice_count": voice_count,
"rtf": rtf,
"first_receive_delay": first_receive_delay,
"punctuation_ratio": (
pred_punctuation_num / label_punctuation_num
if label_punctuation_num > 0
else 1.0
),
"sentence_punctuation_ratio": (
pred_sentence_punctuation_num / label_setence_punctuation_num
if label_setence_punctuation_num > 0
else 1.0
),
}
for k, v in align_start.items():
lang_result["start_word_%dms_ratio" % k] = v
for k, v in align_end.items():
lang_result["end_word_%dms_ratio" % k] = v
if lang == "":
json_result.update(lang_result)
else:
json_result[lang] = lang_result
return json_result
"""
"""
def main():
log.info(f'{TEST_CONCURRENCY=}, {THRESHOLD_OMCER=}')
dataset = load_dataset(DATASET_FILEPATH)
query_datas = merge_query_data(dataset)
#获取 ASR 服务 URL通常从 Kubernetes 配置)
sut_url = get_sut_url()
#创建多个客户端实例(每个客户端监听不同端口,如 80、81、82...
port_base = 80
clients = [ClientCallback(sut_url, port_base + i) for i in range(TEST_CONCURRENCY)]
#准备测试数据与线程
detail_cases = []
# we use the same test data for all requests
query_data = query_datas[0]
test_results = [None] * len(clients)
test_threads = [threading.Thread(target=predict_task, args=(client, task_id, query_data, test_results))
for task_id, client in enumerate(clients)]
#启动并发测试启动线程并间隔10秒设置超时时间为1小时
for t in test_threads:
t.start()
time.sleep(10)
[t.join(timeout=3600) for t in test_threads]
#合并结果与评估
final_result = merge_concurrent_result(test_results)
product_avaiable = all([c.product_avaiable for c in clients])
final_result['concurrent_req'] = TEST_CONCURRENCY
if final_result['one_minus_cer'] < THRESHOLD_OMCER:
product_avaiable = False
if not product_avaiable:
final_result['success'] = False
change_product_available()
else:
final_result['success'] = True
#保存结果,
log.info(
"指标结果为: %s", json.dumps(final_result, indent=2, ensure_ascii=False)
)
time.sleep(120)
#打印并保存最终结果到文件
with open(RESULT_FILEPATH, "w") as f:
json.dump(final_result, f, indent=2, ensure_ascii=False)
#保存详细测试用例结果
with open(DETAILED_CASES_FILEPATH, "w") as f:
json.dump(detail_cases, f, indent=2, ensure_ascii=False)
"""
#############################################################################
import requests
import base64
def _sid_result_format():
    """Shared response-format spec ({encoding, compress, format}) used by every SID call."""
    return {
        "encoding": "utf8",
        "compress": "raw",
        "format": "json"
    }


def _sid_header(APPId):
    """Shared request header; status=3 marks a complete (final-frame) request."""
    return {
        "app_id": APPId,
        "status": 3
    }


def _sid_audio_payload(file_path):
    """Read *file_path* and wrap it as the base64-encoded audio resource payload."""
    with open(file_path, "rb") as f:
        audioBytes = f.read()
    return {
        "resource": {
            "encoding": "lame",
            "sample_rate": 16000,
            "channels": 1,
            "bit_depth": 16,
            "status": 3,
            "audio": str(base64.b64encode(audioBytes), 'UTF-8')
        }
    }


# API calls that carry an audio payload in addition to their parameters.
_SID_AUDIO_APIS = ('createFeature', 'searchFea', 'searchScoreFea', 'updateFeature')


def gen_req_body(apiname, APPId, file_path=None, featureId=None, featureInfo=None, dstFeatureId=None):
    """Build the JSON request body for one voiceprint (SID) API call.

    :param apiname: one of createFeature / createGroup / deleteFeature /
        queryFeatureList / searchFea / searchScoreFea / updateFeature / deleteGroup
    :param APPId: application id placed in the request header
    :param file_path: path of the audio file (required for audio-carrying calls)
    :param featureId: feature id, used by createFeature
    :param featureInfo: feature description, used by createFeature
    :param dstFeatureId: target feature id, used by searchScoreFea
    :return: dict ready to be json-serialized and POSTed
    :raises Exception: if *apiname* is not a known API name
    """
    if apiname == 'createFeature':
        params = {
            "func": "createFeature",
            "groupId": "test_voiceprint_e",
            "featureId": featureId,
            "featureInfo": featureInfo,
            "createFeatureRes": _sid_result_format()
        }
    elif apiname == 'createGroup':
        params = {
            "func": "createGroup",
            "groupId": "test_voiceprint_e",
            "groupName": "vip_user",
            "groupInfo": "store_vip_user_voiceprint",
            "createGroupRes": _sid_result_format()
        }
    elif apiname == 'deleteFeature':
        params = {
            "func": "deleteFeature",
            "groupId": "iFLYTEK_examples_groupId",
            "featureId": "iFLYTEK_examples_featureId",
            "deleteFeatureRes": _sid_result_format()
        }
    elif apiname == 'queryFeatureList':
        params = {
            "func": "queryFeatureList",
            "groupId": "user_voiceprint_2",
            "queryFeatureListRes": _sid_result_format()
        }
    elif apiname == 'searchFea':
        params = {
            "func": "searchFea",
            "groupId": "test_voiceprint_e",
            "topK": 1,
            "searchFeaRes": _sid_result_format()
        }
    elif apiname == 'searchScoreFea':
        params = {
            "func": "searchScoreFea",
            "groupId": "test_voiceprint_e",
            "dstFeatureId": dstFeatureId,
            "searchScoreFeaRes": _sid_result_format()
        }
    elif apiname == 'updateFeature':
        params = {
            "func": "updateFeature",
            "groupId": "iFLYTEK_examples_groupId",
            "featureId": "iFLYTEK_examples_featureId",
            "featureInfo": "iFLYTEK_examples_featureInfo_update",
            "updateFeatureRes": _sid_result_format()
        }
    elif apiname == 'deleteGroup':
        params = {
            "func": "deleteGroup",
            "groupId": "iFLYTEK_examples_groupId",
            "deleteGroupRes": _sid_result_format()
        }
    else:
        raise Exception(
            "输入的apiname不在[createFeature, createGroup, deleteFeature, queryFeatureList, searchFea, searchScoreFea,updateFeature]内,请检查")
    body = {
        "header": _sid_header(APPId),
        "parameter": {
            "s782b4996": params
        }
    }
    if apiname in _SID_AUDIO_APIS:
        # Audio-carrying calls attach the base64 payload (reads file_path here).
        body["payload"] = _sid_audio_payload(file_path)
    return body
log.info(f"开始请求获取到SUT服务URL")
# Resolve the SUT (system-under-test) service URL. This registers the SUT
# deployment via Kubernetes at import time (module side effect; may block).
sut_url = get_sut_url()
print(f"获取到的SUT_URL: {sut_url}")  # debug output
log.info(f"获取到SUT服务URL: {sut_url}")
from urllib.parse import urlparse
# Global holder for the last base64-decoded `text` field. NOTE(review): only
# the disabled decoding path in req_url would write it; live code never does.
text_decoded = None
###################################新增新增################################
def req_url(api_name, APPId, file_path=None, featureId=None, featureInfo=None, dstFeatureId=None):
    """Run the evaluation request against the SUT and persist its JSON reply.

    :param api_name: SID API name forwarded to gen_req_body
    :param APPId: app id used for the (currently discarded) SID headers
    :param file_path: audio file path forwarded into the SID body
    :return: None; writes the final response JSON to RESULT_FILEPATH
    """
    global text_decoded
    body = gen_req_body(apiname=api_name, APPId=APPId, file_path=file_path, featureId=featureId, featureInfo=featureInfo, dstFeatureId=dstFeatureId)
    #request_url = 'https://ai-cloud.4paradigm.com:9443/sid/v1/private/s782b4996'
    #request_url = 'https://sut:80/sid/v1/private/s782b4996'
    #headers = {'content-type': "application/json", 'host': 'ai-cloud.4paradigm.com', 'appid': APPId}
    parsed_url = urlparse(sut_url)
    headers = {'content-type': "application/json", 'host': parsed_url.hostname, 'appid': APPId}
    # 1. Service health check first.
    response = requests.get(f"{sut_url}/health")
    print(response.status_code, response.text)
    # NOTE(review): the SID `body` and `headers` built above are discarded here;
    # the request actually sent uses the hard-coded {"limit": 20} body and plain
    # JSON headers below. Confirm this replacement is intentional.
    # Request headers
    headers = {"Content-Type": "application/json"}
    # Request body (optional `limit` caps how many images are processed)
    body = {"limit": 20 }
    # Send the POST request
    response = requests.post(
        f"{sut_url}/v1/private/s782b4996",
        data=json.dumps(body),
        headers=headers
    )
    # Parse and report the response
    if response.status_code == 200:
        result = response.json()
        print("预测评估结果:")
        print(f"准确率: {result['metrics']['accuracy']}%")
        print(f"平均召回率: {result['metrics']['average_recall']}%")
        print(f"处理图片总数: {result['metrics']['total_images']}")
    else:
        print(f"请求失败,状态码: {response.status_code}")
        print(f"错误信息: {response.text}")
    # NOTE(review): hard-coded basic-auth credentials committed to source and
    # never used by the live request above — remove or move to secret storage.
    auth = ('llm', 'Rmf4#LcG(iFZrjU;2J')
    #response = requests.post(request_url, data=json.dumps(body), headers=headers, auth=auth)
    #response = requests.post(sut_url + "/predict", data=json.dumps(body), headers=headers, auth=auth)
    #response = requests.post(f"{sut_url}/sid/v1/private/s782b4996", data=json.dumps(body), headers=headers, auth=auth)
    """
    response = requests.post(f"{sut_url}/v1/private/s782b4996", data=json.dumps(body), headers=headers)
    """
    #print("HTTP状态码:", response.status_code)
    #print("原始响应内容:", response.text)
    #print(f"请求URL: {sut_url + '/v1/private/s782b4996'}")
    #print(f"请求headers: {headers}")
    #print(f"请求body: {body}")
    #tempResult = json.loads(response.content.decode('utf-8'))
    #print(tempResult)
    """
    # 对text字段进行Base64解码
    if 'payload' in tempResult and 'updateFeatureRes' in tempResult['payload']:
        text_encoded = tempResult['payload']['updateFeatureRes']['text']
        text_decoded = base64.b64decode(text_encoded).decode('utf-8')
        print(f"Base64解码后的text字段内容: {text_decoded}")
    """
    #text_encoded = tempResult['payload']['updateFeatureRes']['text']
    #text_decoded = base64.b64decode(text_encoded).decode('utf-8')
    #print(f"Base64解码后的text字段内容: {text_decoded}")
    # Persist the (possibly non-200) response JSON as the benchmark result.
    result = response.json()
    with open(RESULT_FILEPATH, "w") as f:
        json.dump(result, f, indent=4, ensure_ascii=False)
    print(f"结果已成功写入 {RESULT_FILEPATH}")
# Output-path fallbacks for local runs: reuse the leaderboard env vars when
# present, otherwise write under ./tests и ./out-style defaults.
submit_config_filepath = os.getenv("SUBMIT_CONFIG_FILEPATH", "./tests/resources/submit_config")
result_filepath = os.getenv("RESULT_FILEPATH", "./out/result")
bad_cases_filepath = os.getenv("BAD_CASES_FILEPATH", "./out/badcase")
#detail_cases_filepath = os.getenv("DETAILED_CASES_FILEPATH", "./out/detailcase.jsonl")
from typing import Any, Dict, List, Optional
def result2file(
    result: Dict[str, Any],
    detail_cases: Optional[List[Dict[str, Any]]] = None
):
    """Write the final metrics dict to ``result_filepath`` as indented JSON.

    :param result: metrics mapping to persist; nothing is written when None.
    :param detail_cases: accepted for interface compatibility; detailed-case
        dumping is currently disabled, so this argument is ignored.
    """
    # Output paths are the module-level getenv() fallbacks defined above.
    assert result_filepath is not None
    assert bad_cases_filepath is not None
    if result is not None:
        with open(result_filepath, "w") as f:
            json.dump(result, f, indent=4, ensure_ascii=False)
def test_image_prediction(sut_url, image_path):
    """Send one image to the SUT prediction endpoint.

    Returns ``(top_prediction, None)`` on success, or ``(None, error_message)``
    when the transport fails or the server reports a non-success status.
    """
    endpoint = f"{sut_url}/v1/private/s782b4996"
    try:
        with open(image_path, 'rb') as img:
            response = requests.post(endpoint, files={'image': img}, timeout=30)
        result = response.json()
        if result.get('status') == 'success':
            return result.get('top_prediction'), None
        return None, f"服务端错误: {result.get('message')}"
    except Exception as e:
        return None, f"请求错误: {str(e)}"
import random
import time
#from tqdm import tqdm
import os
import requests
if __name__ == '__main__':
    # Image-classification evaluation driver: walk the Caltech-256-style
    # dataset extracted under /tmp/workspace, sample a few images per class,
    # send each to the SUT for prediction, then write accuracy / recall /
    # throughput via result2file(). (Previously-disabled code paths that were
    # kept as dead string literals have been removed.)
    print(f"\n===== main开始请求接口 ===============================================")
    # 1. Service health check.
    print(f"\n===== 服务健康检查 ===================================================")
    response = requests.get(f"{sut_url}/health")
    print(response.status_code, response.text)
    dataset_root = "/tmp/workspace/256ObjectCategoriesNew"  # dataset root directory
    samples_per_class = 3  # images sampled per class
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')  # accepted image types
    # Counters. A failed request counts as a false negative; a wrong
    # prediction counts as both a false positive and a false negative.
    total_samples = 0
    true_positives = 0
    false_positives = 0
    false_negatives = 0
    total_processing_time = 0.0  # wall-clock seconds incl. network round-trip
    # Iterate over every "NNN.classname" category folder.
    for folder_name in os.listdir(dataset_root):
        folder_path = os.path.join(dataset_root, folder_name)
        # Skip non-directory entries.
        if not os.path.isdir(folder_path):
            continue
        # Class name is the part after the first dot of the folder name.
        try:
            class_name = folder_name.split('.', 1)[1].strip().lower()
        except IndexError:
            print(f"警告:文件夹 {folder_name} 命名格式不正确,跳过该文件夹")
            continue
        # Collect the folder's image files.
        image_files = []
        for file in os.listdir(folder_path):
            file_path = os.path.join(folder_path, file)
            if os.path.isfile(file_path) and file.lower().endswith(image_extensions):
                image_files.append(file_path)
        # Randomly sample up to samples_per_class images (all if fewer).
        selected_images = random.sample(
            image_files,
            min(samples_per_class, len(image_files))
        )
        # Evaluate each sampled image.
        for img_path in selected_images:
            total_samples += 1
            start_time = time.time()
            prediction, error = test_image_prediction(sut_url, img_path)
            # Per-image latency covers the request and model inference.
            processing_time = time.time() - start_time
            total_processing_time += processing_time
            if error:
                print(f"处理图片 {img_path} 失败: {error}")
                # Treat a failed sample as a missed (false-negative) prediction.
                false_negatives += 1
                continue
            pred_class = prediction.get('class_name', '').lower()
            confidence = prediction.get('confidence', 0)
            # Correct when the true class name is a substring of the prediction
            # (case-insensitive).
            is_correct = class_name in pred_class
            if is_correct:
                true_positives += 1
            else:
                false_positives += 1
                false_negatives += 1
            print(f"图片: {os.path.basename(img_path)} | 真实: {class_name} | 预测: {pred_class} | 置信度: {confidence:.4f} | {'正确' if is_correct else '错误'}")
    # Aggregate metrics (percentages); all divisions are guarded so an empty
    # dataset still produces a well-formed result file.
    result = {
        "total_processing_time": round(total_processing_time, 6),
        "throughput": 0.0,
        "accuracy": 0.0,
        "recall": 0.0
    }
    if total_samples == 0:
        print("没有找到任何图片样本")
    # Accuracy = correct predictions / total predictions.
    accuracy = true_positives / total_samples * 100 if total_samples > 0 else 0
    # Recall = correct predictions / (correct + missed positives).
    recall_denominator = true_positives + false_negatives
    recall = true_positives / recall_denominator * 100 if recall_denominator > 0 else 0
    # Throughput in images/second; guard against a near-zero elapsed time.
    throughput = total_samples / total_processing_time if total_processing_time > 1e-6 else 0
    result.update({
        "throughput": round(throughput, 6),
        "accuracy": round(accuracy, 6),
        "recall": round(recall, 6)
    })
    # Print the final summary.
    print("\n" + "="*50)
    print(f"总样本数: {total_samples}")
    print(f"总处理时间: {total_processing_time:.4f}")
    print(f"处理速度: {throughput:.2f}张/秒")
    print(f"正确预测: {true_positives}")
    print(f"错误预测: {total_samples - true_positives}")
    print(f"准确率: {accuracy:.4f} ({true_positives}/{total_samples})")
    print(f"召回率: {recall:.4f} ({true_positives}/{recall_denominator})")
    print("="*50)
    result2file(result)
    exit_code = 0