720 lines
23 KiB
Python
720 lines
23 KiB
Python
import json
|
||
import os
|
||
import sys
|
||
import time
|
||
import tempfile
|
||
import zipfile
|
||
import threading
|
||
from collections import defaultdict
|
||
from typing import Dict, List
|
||
|
||
import yaml
|
||
from pydantic import ValidationError
|
||
|
||
from schemas.dataset import QueryData
|
||
from utils.client_callback import ClientCallback, EvaluateResult, StopException
|
||
from utils.logger import log
|
||
from utils.service import register_sut
|
||
from utils.update_submit import change_product_available
|
||
from utils.file import dump_json, load_yaml, unzip_dir, load_json, write_file, dump_yaml
|
||
from utils.leaderboard import change_product_unavailable
|
||
|
||
|
||
# Global lock (available for guarding shared state across worker threads).
lck = threading.Lock()

# --- Environment variables injected by the leaderboard (required; KeyError if absent) ---
DATASET_FILEPATH = os.environ["DATASET_FILEPATH"]
RESULT_FILEPATH = os.environ["RESULT_FILEPATH"]

DETAILED_CASES_FILEPATH = os.environ["DETAILED_CASES_FILEPATH"]
SUBMIT_CONFIG_FILEPATH = os.environ["SUBMIT_CONFIG_FILEPATH"]
BENCHMARK_NAME = os.environ["BENCHMARK_NAME"]
TEST_CONCURRENCY = int(os.getenv('TEST_CONCURRENCY', 1))
THRESHOLD_OMCER = float(os.getenv('THRESHOLD_OMCER', 0.8))

log.info(f"DATASET_FILEPATH: {DATASET_FILEPATH}")
workspace_path = "/tmp/workspace"  # dataset zip is unpacked here (see unzip_dir below)

# --- Environment variables injected by kubernetes ---
MY_POD_IP = os.environ["MY_POD_IP"]

# --- Constants ---
RESOURCE_NAME = BENCHMARK_NAME

# --- Environment variables from judge_flow_config ---
LANG = os.getenv("lang")
SUT_CPU = os.getenv("SUT_CPU", "2")
SUT_MEMORY = os.getenv("SUT_MEMORY", "4Gi")
SUT_VGPU = os.getenv("SUT_VGPU", "1")
#SUT_VGPU_MEM = os.getenv("SUT_VGPU_MEM", str(1843 * int(SUT_VGPU)))
#SUT_VGPU_CORES = os.getenv("SUT_VGPU_CORES", str(8 * int(SUT_VGPU)))
SUT_VGPU_ACCELERATOR = os.getenv("SUT_VGPU_ACCELERATOR", "iluvatar-BI-V100")
RESOURCE_TYPE = os.getenv("RESOURCE_TYPE", "vgpu")
assert RESOURCE_TYPE in [
    "cpu",
    "vgpu",
], "benchmark judge_flow_config error: RESOURCE_TYPE should be cpu or vgpu"


# Unpack the dataset archive into the workspace (runs at import time).
unzip_dir(DATASET_FILEPATH, workspace_path)
|
||
|
||
def get_sut_url_kubernetes():
    """Register the system-under-test (SUT) with Kubernetes and return its HTTP URL.

    Loads the submit-config YAML, injects the SUT container spec and the
    resource requests/limits derived from the judge_flow_config environment
    variables, registers the SUT, and rewrites the returned websocket URL
    scheme to plain HTTP.

    Returns:
        str: base HTTP URL of the registered SUT service.
    """
    with open(SUBMIT_CONFIG_FILEPATH, "r") as f:
        submit_config = yaml.safe_load(f)
    assert isinstance(submit_config, dict)

    submit_config.setdefault("values", {})

    # Container spec for the SUT pod; the image just sleeps so the model
    # process can be started inside it by the submission flow.
    submit_config["values"]["containers"] = [
        {
            "name": "corex-container",
            "image": "harbor.4pd.io/lab-platform/inf/python:3.9",
            "command": ["sleep"],
            "args": ["3600"],
        }
    ]

    submit_config["values"]["resources"] = {
        "requests": {},
        "limits": {},
    }

    limits = submit_config["values"]["resources"]["limits"]
    requests = submit_config["values"]["resources"]["requests"]

    # Only request accelerator resources when the judge flow asked for vgpu.
    # RESOURCE_TYPE is asserted at import time to be "cpu" or "vgpu"; the
    # previous code injected GPU resources unconditionally, which made the
    # "cpu" setting meaningless.
    if RESOURCE_TYPE == "vgpu":
        vgpu_resource = {
            "iluvatar.ai/gpu": SUT_VGPU,  # GPU resource key for the cluster
            # If further resources (e.g. GPU memory) are needed, add them here:
            # "iluvatar.ai/gpumem": SUT_VGPU_MEM,
        }
        limits.update(vgpu_resource)
        requests.update(vgpu_resource)
        # Node selector: use the configured accelerator label rather than a
        # hard-coded "iluvatar-BI-V100" (which is the env-var default anyway).
        submit_config["values"]["nodeSelector"] = {
            "contest.4pd.io/accelerator": SUT_VGPU_ACCELERATOR
        }

    log.info(f"submit_config: {submit_config}")
    log.info(f"RESOURCE_NAME: {RESOURCE_NAME}")

    # register_sut returns a ws:// endpoint; callers need plain HTTP.
    return register_sut(submit_config, RESOURCE_NAME).replace(
        "ws://", "http://"
    )
|
||
|
||
|
||
def get_sut_url():
    """Resolve the SUT base URL (currently always via the Kubernetes path)."""
    return get_sut_url_kubernetes()
|
||
|
||
#SUT_URL = get_sut_url()
|
||
#os.environ["SUT_URL"] = SUT_URL
|
||
|
||
|
||
|
||
#############################################################################
|
||
|
||
import requests
|
||
import base64
|
||
|
||
def gen_req_body(apiname, APPId, file_path=None, featureId=None, featureInfo=None, dstFeatureId=None):
|
||
"""
|
||
生成请求的body
|
||
:param apiname
|
||
:param APPId: Appid
|
||
:param file_name: 文件路径
|
||
:return:
|
||
"""
|
||
if apiname == 'createFeature':
|
||
|
||
with open(file_path, "rb") as f:
|
||
audioBytes = f.read()
|
||
body = {
|
||
"header": {
|
||
"app_id": APPId,
|
||
"status": 3
|
||
},
|
||
"parameter": {
|
||
"s782b4996": {
|
||
"func": "createFeature",
|
||
"groupId": "test_voiceprint_e",
|
||
"featureId": featureId,
|
||
"featureInfo": featureInfo,
|
||
"createFeatureRes": {
|
||
"encoding": "utf8",
|
||
"compress": "raw",
|
||
"format": "json"
|
||
}
|
||
}
|
||
},
|
||
"payload": {
|
||
"resource": {
|
||
"encoding": "lame",
|
||
"sample_rate": 16000,
|
||
"channels": 1,
|
||
"bit_depth": 16,
|
||
"status": 3,
|
||
"audio": str(base64.b64encode(audioBytes), 'UTF-8')
|
||
}
|
||
}
|
||
}
|
||
elif apiname == 'createGroup':
|
||
|
||
body = {
|
||
"header": {
|
||
"app_id": APPId,
|
||
"status": 3
|
||
},
|
||
"parameter": {
|
||
"s782b4996": {
|
||
"func": "createGroup",
|
||
"groupId": "test_voiceprint_e",
|
||
"groupName": "vip_user",
|
||
"groupInfo": "store_vip_user_voiceprint",
|
||
"createGroupRes": {
|
||
"encoding": "utf8",
|
||
"compress": "raw",
|
||
"format": "json"
|
||
}
|
||
}
|
||
}
|
||
}
|
||
elif apiname == 'deleteFeature':
|
||
|
||
body = {
|
||
"header": {
|
||
"app_id": APPId,
|
||
"status": 3
|
||
|
||
},
|
||
"parameter": {
|
||
"s782b4996": {
|
||
"func": "deleteFeature",
|
||
"groupId": "iFLYTEK_examples_groupId",
|
||
"featureId": "iFLYTEK_examples_featureId",
|
||
"deleteFeatureRes": {
|
||
"encoding": "utf8",
|
||
"compress": "raw",
|
||
"format": "json"
|
||
}
|
||
}
|
||
}
|
||
}
|
||
elif apiname == 'queryFeatureList':
|
||
|
||
body = {
|
||
"header": {
|
||
"app_id": APPId,
|
||
"status": 3
|
||
},
|
||
"parameter": {
|
||
"s782b4996": {
|
||
"func": "queryFeatureList",
|
||
"groupId": "user_voiceprint_2",
|
||
"queryFeatureListRes": {
|
||
"encoding": "utf8",
|
||
"compress": "raw",
|
||
"format": "json"
|
||
}
|
||
}
|
||
}
|
||
}
|
||
elif apiname == 'searchFea':
|
||
|
||
with open(file_path, "rb") as f:
|
||
audioBytes = f.read()
|
||
body = {
|
||
"header": {
|
||
"app_id": APPId,
|
||
"status": 3
|
||
},
|
||
"parameter": {
|
||
"s782b4996": {
|
||
"func": "searchFea",
|
||
"groupId": "test_voiceprint_e",
|
||
"topK": 1,
|
||
"searchFeaRes": {
|
||
"encoding": "utf8",
|
||
"compress": "raw",
|
||
"format": "json"
|
||
}
|
||
}
|
||
},
|
||
"payload": {
|
||
"resource": {
|
||
"encoding": "lame",
|
||
"sample_rate": 16000,
|
||
"channels": 1,
|
||
"bit_depth": 16,
|
||
"status": 3,
|
||
"audio": str(base64.b64encode(audioBytes), 'UTF-8')
|
||
}
|
||
}
|
||
}
|
||
elif apiname == 'searchScoreFea':
|
||
|
||
with open(file_path, "rb") as f:
|
||
audioBytes = f.read()
|
||
body = {
|
||
"header": {
|
||
"app_id": APPId,
|
||
"status": 3
|
||
},
|
||
"parameter": {
|
||
"s782b4996": {
|
||
"func": "searchScoreFea",
|
||
"groupId": "test_voiceprint_e",
|
||
"dstFeatureId": dstFeatureId,
|
||
"searchScoreFeaRes": {
|
||
"encoding": "utf8",
|
||
"compress": "raw",
|
||
"format": "json"
|
||
}
|
||
}
|
||
},
|
||
"payload": {
|
||
"resource": {
|
||
"encoding": "lame",
|
||
"sample_rate": 16000,
|
||
"channels": 1,
|
||
"bit_depth": 16,
|
||
"status": 3,
|
||
"audio": str(base64.b64encode(audioBytes), 'UTF-8')
|
||
}
|
||
}
|
||
}
|
||
elif apiname == 'updateFeature':
|
||
|
||
with open(file_path, "rb") as f:
|
||
audioBytes = f.read()
|
||
body = {
|
||
"header": {
|
||
"app_id": APPId,
|
||
"status": 3
|
||
},
|
||
"parameter": {
|
||
"s782b4996": {
|
||
"func": "updateFeature",
|
||
"groupId": "iFLYTEK_examples_groupId",
|
||
"featureId": "iFLYTEK_examples_featureId",
|
||
"featureInfo": "iFLYTEK_examples_featureInfo_update",
|
||
"updateFeatureRes": {
|
||
"encoding": "utf8",
|
||
"compress": "raw",
|
||
"format": "json"
|
||
}
|
||
}
|
||
},
|
||
"payload": {
|
||
"resource": {
|
||
"encoding": "lame",
|
||
"sample_rate": 16000,
|
||
"channels": 1,
|
||
"bit_depth": 16,
|
||
"status": 3,
|
||
"audio": str(base64.b64encode(audioBytes), 'UTF-8')
|
||
}
|
||
}
|
||
}
|
||
elif apiname == 'deleteGroup':
|
||
body = {
|
||
"header": {
|
||
"app_id": APPId,
|
||
"status": 3
|
||
},
|
||
"parameter": {
|
||
"s782b4996": {
|
||
"func": "deleteGroup",
|
||
"groupId": "iFLYTEK_examples_groupId",
|
||
"deleteGroupRes": {
|
||
"encoding": "utf8",
|
||
"compress": "raw",
|
||
"format": "json"
|
||
}
|
||
}
|
||
}
|
||
}
|
||
else:
|
||
raise Exception(
|
||
"输入的apiname不在[createFeature, createGroup, deleteFeature, queryFeatureList, searchFea, searchScoreFea,updateFeature]内,请检查")
|
||
return body
|
||
|
||
|
||
|
||
# Resolve (and, as a side effect, register/deploy) the SUT service at import time.
log.info(f"开始请求获取到SUT服务URL")
# Fetch the SUT service URL
sut_url = get_sut_url()
print(f"获取到的SUT_URL: {sut_url}")  # debug output
log.info(f"获取到SUT服务URL: {sut_url}")

from urllib.parse import urlparse

# Global variable; declared ``global`` in req_url but the assignment there is
# commented out, so it currently stays None.
text_decoded = None
|
||
|
||
###################################新增新增################################
|
||
def req_url(api_name, APPId, file_path=None, featureId=None, featureInfo=None, dstFeatureId=None):
    """
    Fire a request against the SUT service and persist the response.

    NOTE(review): despite the name, the voiceprint body built via
    ``gen_req_body`` is discarded below and replaced by a ``{"limit": 20}``
    body POSTed to the image-prediction endpoint — presumably leftover from
    an earlier protocol; confirm which request is actually intended.

    :param api_name: API name forwarded to gen_req_body
    :param APPId: APPID placed in the request
    :param file_path: file path placed in the body
    :return: None (writes the final response JSON to RESULT_FILEPATH)
    """

    global text_decoded  # declared global, but never assigned here (dead at present)

    body = gen_req_body(apiname=api_name, APPId=APPId, file_path=file_path, featureId=featureId, featureInfo=featureInfo, dstFeatureId=dstFeatureId)
    #request_url = 'https://ai-cloud.4paradigm.com:9443/sid/v1/private/s782b4996'

    #request_url = 'https://sut:80/sid/v1/private/s782b4996'

    #headers = {'content-type': "application/json", 'host': 'ai-cloud.4paradigm.com', 'appid': APPId}

    parsed_url = urlparse(sut_url)
    headers = {'content-type': "application/json", 'host': parsed_url.hostname, 'appid': APPId}

    # 1. Service health check first
    response = requests.get(f"{sut_url}/health")
    print(response.status_code, response.text)


    # Request headers — NOTE(review): this overwrites the headers built above.
    headers = {"Content-Type": "application/json"}
    # Request body — NOTE(review): this overwrites the gen_req_body result.
    body = {"limit": 20 }  # optional parameter capping the number of images processed

    # Send the POST request
    response = requests.post(
        f"{sut_url}/v1/private/s782b4996",
        data=json.dumps(body),
        headers=headers
    )

    # Parse and print the evaluation response
    if response.status_code == 200:
        result = response.json()
        print("预测评估结果:")
        print(f"准确率: {result['metrics']['accuracy']}%")
        print(f"平均召回率: {result['metrics']['average_recall']}%")
        print(f"处理图片总数: {result['metrics']['total_images']}")
    else:
        print(f"请求失败,状态码: {response.status_code}")
        print(f"错误信息: {response.text}")


    # Basic-auth credentials — NOTE(review): unused below, and hard-coding a
    # secret in source is a security risk; move it to env/secret storage.
    auth = ('llm', 'Rmf4#LcG(iFZrjU;2J')
    #response = requests.post(request_url, data=json.dumps(body), headers=headers, auth=auth)

    #response = requests.post(sut_url + "/predict", data=json.dumps(body), headers=headers, auth=auth)
    #response = requests.post(f"{sut_url}/sid/v1/private/s782b4996", data=json.dumps(body), headers=headers, auth=auth)
    """
    response = requests.post(f"{sut_url}/v1/private/s782b4996", data=json.dumps(body), headers=headers)
    """


    #print("HTTP状态码:", response.status_code)
    #print("原始响应内容:", response.text) # 先打印原始内容
    #print(f"请求URL: {sut_url + '/v1/private/s782b4996'}")
    #print(f"请求headers: {headers}")
    #print(f"请求body: {body}")


    #tempResult = json.loads(response.content.decode('utf-8'))
    #print(tempResult)

    """
    # 对text字段进行Base64解码
    if 'payload' in tempResult and 'updateFeatureRes' in tempResult['payload']:
        text_encoded = tempResult['payload']['updateFeatureRes']['text']
        text_decoded = base64.b64decode(text_encoded).decode('utf-8')
        print(f"Base64解码后的text字段内容: {text_decoded}")
    """

    #text_encoded = tempResult['payload']['updateFeatureRes']['text']
    #text_decoded = base64.b64decode(text_encoded).decode('utf-8')
    #print(f"Base64解码后的text字段内容: {text_decoded}")


    # Fetch the response JSON and persist it to the leaderboard result file
    result = response.json()
    with open(RESULT_FILEPATH, "w") as f:
        json.dump(result, f, indent=4, ensure_ascii=False)
    print(f"结果已成功写入 {RESULT_FILEPATH}")
|
||
|
||
# Fallback output paths (env-overridable) used by result2file below.
# NOTE(review): lowercase duplicates of the uppercase required env vars above;
# these tolerate missing env vars via defaults, the uppercase ones do not.
submit_config_filepath = os.getenv("SUBMIT_CONFIG_FILEPATH", "./tests/resources/submit_config")
result_filepath = os.getenv("RESULT_FILEPATH", "./out/result")
bad_cases_filepath = os.getenv("BAD_CASES_FILEPATH", "./out/badcase")
#detail_cases_filepath = os.getenv("DETAILED_CASES_FILEPATH", "./out/detailcase.jsonl")

from typing import Any, Dict, List
|
||
|
||
def result2file(
    result: Dict[str, Any],
    detail_cases: List[Dict[str, Any]] = None
):
    """Write the final metrics dict to the configured result file path.

    ``detail_cases`` is accepted for interface compatibility but is currently
    unused (the detailed-case dump is disabled).
    """
    assert result_filepath is not None
    assert bad_cases_filepath is not None

    # Nothing to persist.
    if result is None:
        return

    with open(result_filepath, "w") as out:
        json.dump(result, out, indent=4, ensure_ascii=False)
|
||
|
||
|
||
def test_image_prediction(sut_url, image_path):
    """POST a single image to the SUT prediction endpoint.

    :param sut_url: base URL of the SUT service
    :param image_path: path of the image file to send
    :return: ``(result_dict, None)`` on success, ``(None, error_message)``
        on any failure (HTTP error, service-reported error, or exception)
    """
    url = f"{sut_url}/v1/private/s782b4996"

    try:
        with open(image_path, 'rb') as f:
            files = {'image': f}
            response = requests.post(url, files=files, timeout=30)

        # Fail fast on HTTP-level errors before attempting JSON decode, so a
        # 4xx/5xx error page yields a clear message instead of surfacing as a
        # confusing JSON parse exception.
        if response.status_code != 200:
            return None, f"HTTP error: status {response.status_code}"

        result = response.json()
        if result.get('status') != 'success':
            return None, f"服务端错误: {result.get('message')}"

        return result, None
    except Exception as e:
        return None, f"请求错误: {str(e)}"
|
||
|
||
|
||
|
||
import random
|
||
import time
|
||
#from tqdm import tqdm
|
||
import os
|
||
import requests
|
||
|
||
if __name__ == '__main__':

    print(f"\n===== main开始请求接口 ===============================================")

    # 1. Service health check first
    print(f"\n===== 服务健康检查 ===================================================")
    response = requests.get(f"{sut_url}/health")
    print(response.status_code, response.text)


    ###############################################################################################
    dataset_root = "/tmp/workspace/256ObjectCategoriesNew"  # dataset root directory
    samples_per_class = 3  # number of images sampled per class
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')  # supported image formats

    # Aggregate counters
    total_samples = 0

    # GPU statistics
    gpu_true_positives = 0
    gpu_false_positives = 0
    gpu_false_negatives = 0
    gpu_total_processing_time = 0.0

    # CPU statistics
    cpu_true_positives = 0
    cpu_false_positives = 0
    cpu_false_negatives = 0
    cpu_total_processing_time = 0.0

    # Walk every class folder in the dataset
    for folder_name in os.listdir(dataset_root):
        folder_path = os.path.join(dataset_root, folder_name)

        # Skip non-directory entries
        if not os.path.isdir(folder_path):
            continue

        # Extract the class name from the "<index>.<name>" folder naming scheme
        try:
            class_name = folder_name.split('.', 1)[1].strip().lower()
        except IndexError:
            print(f"警告:文件夹 {folder_name} 命名格式不正确,跳过该文件夹")
            continue

        # Collect all image files in the folder
        image_files = []
        for file in os.listdir(folder_path):
            file_path = os.path.join(folder_path, file)
            if os.path.isfile(file_path) and file.lower().endswith(image_extensions):
                image_files.append(file_path)

        # Randomly sample up to samples_per_class images (all if fewer exist)
        selected_images = random.sample(
            image_files,
            min(samples_per_class, len(image_files))
        )

        for img_path in selected_images:
            total_samples += 1

            # Query the SUT for a prediction on this image
            prediction, error = test_image_prediction(sut_url, img_path)

            print(f"test_image_prediction返回的prediction: {prediction}")
            print(f"test_image_prediction返回的error: {error}")

            if error:
                print(f"处理图片 {img_path} 失败: {error}")
                continue

            # Parse the GPU prediction
            gpu_pred = prediction.get('cuda_prediction', {})
            gpu_pred_class = gpu_pred.get('class_name', '').lower()
            gpu_processing_time = gpu_pred.get('processing_time', 0.0)

            # Parse the CPU prediction
            cpu_pred = prediction.get('cpu_prediction', {})
            cpu_pred_class = cpu_pred.get('class_name', '').lower()
            cpu_processing_time = cpu_pred.get('processing_time', 0.0)

            # GPU correctness (substring match against the true class name)
            gpu_is_correct = class_name in gpu_pred_class
            if gpu_is_correct:
                gpu_true_positives += 1
            else:
                gpu_false_positives += 1
                gpu_false_negatives += 1

            # CPU correctness
            cpu_is_correct = class_name in cpu_pred_class
            if cpu_is_correct:
                cpu_true_positives += 1
            else:
                cpu_false_positives += 1
                cpu_false_negatives += 1

            # Accumulate processing times
            gpu_total_processing_time += gpu_processing_time
            cpu_total_processing_time += cpu_processing_time

            # Per-image detail output
            print(f"图片: {os.path.basename(img_path)} | 真实: {class_name}")
            print(f"GPU预测: {gpu_pred_class} | {'正确' if gpu_is_correct else '错误'} | 耗时: {gpu_processing_time:.6f}s")
            print(f"CPU预测: {cpu_pred_class} | {'正确' if cpu_is_correct else '错误'} | 耗时: {cpu_processing_time:.6f}s")
            print("-" * 50)


    # Result dict skeleton (throughput/accuracy filled in below)
    result = {
        # GPU metrics
        "gpu_accuracy": 0.0,
        "gpu_recall": 0.0,
        "gpu_running_time": round(gpu_total_processing_time, 6),
        "gpu_throughput": 0.0,

        # CPU metrics
        "cpu_accuracy": 0.0,
        "cpu_recall": 0.0,
        "cpu_running_time": round(cpu_total_processing_time, 6),
        "cpu_throughput": 0.0
    }

    # GPU metrics — guard total_samples == 0 (empty/missing dataset) which
    # previously raised ZeroDivisionError.
    gpu_accuracy = gpu_true_positives / total_samples * 100 if total_samples > 0 else 0
    gpu_recall_denominator = gpu_true_positives + gpu_false_negatives
    gpu_recall = gpu_true_positives / gpu_recall_denominator * 100 if gpu_recall_denominator > 0 else 0
    gpu_throughput = total_samples / gpu_total_processing_time if gpu_total_processing_time > 1e-6 else 0

    # CPU metrics — same zero-sample guard
    cpu_accuracy = cpu_true_positives / total_samples * 100 if total_samples > 0 else 0
    cpu_recall_denominator = cpu_true_positives + cpu_false_negatives
    cpu_recall = cpu_true_positives / cpu_recall_denominator * 100 if cpu_recall_denominator > 0 else 0
    cpu_throughput = total_samples / cpu_total_processing_time if cpu_total_processing_time > 1e-6 else 0

    # Fill in the computed metrics
    result.update({
        "gpu_accuracy": round(gpu_accuracy, 6),
        "gpu_recall": round(gpu_recall, 6),
        "gpu_throughput": round(gpu_throughput, 6),

        "cpu_accuracy": round(cpu_accuracy, 6),
        "cpu_recall": round(cpu_recall, 6),
        "cpu_throughput": round(cpu_throughput, 6)
    })


    # Print the final summary
    print("\n" + "="*50)
    print(f"总样本数: {total_samples}")
    print("\nGPU指标:")
    print(f"准确率: {result['gpu_accuracy']:.4f}%")
    print(f"召回率: {result['gpu_recall']:.4f}%")
    print(f"总运行时间: {result['gpu_running_time']:.6f}s")
    print(f"吞吐量: {result['gpu_throughput']:.2f}张/秒")

    print("\nCPU指标:")
    print(f"准确率: {result['cpu_accuracy']:.4f}%")
    print(f"召回率: {result['cpu_recall']:.4f}%")
    print(f"总运行时间: {result['cpu_running_time']:.6f}s")
    print(f"吞吐量: {result['cpu_throughput']:.2f}张/秒")
    print("="*50)


    # Persist metrics to the leaderboard result file
    result2file(result)

    # If GPU and CPU accuracy diverge by more than 3 points the model output
    # is considered inconsistent and the product is flagged unavailable.
    if abs(gpu_accuracy - cpu_accuracy) > 3:
        log.error(f"gpu与cpu准确率差别超过3%,模型结果不正确")
        change_product_unavailable()


    exit_code = 0
|
||
|
||
|