"""Judge-flow evaluation script.

Registers a system-under-test (SUT) on Kubernetes, health-checks it, sends
sampled images from the unpacked dataset to the SUT prediction endpoint, and
writes aggregate accuracy / recall / throughput metrics to RESULT_FILEPATH.

NOTE(review): the original file carried several large blocks of legacy ASR
evaluation code (load_dataset / merge_*_result / concurrent main) and an
alternative tolerations/GPU configuration, all disabled via module-level
triple-quoted strings.  They were dead code and have been removed; retrieve
them from version control if ever needed again.
"""

import base64
import json
import os
import random
import sys
import tempfile
import threading
import time
import zipfile
from collections import defaultdict
from typing import Any, Dict, List
from urllib.parse import urlparse

import requests
import yaml
from pydantic import ValidationError

from schemas.dataset import QueryData
from utils.client_callback import ClientCallback, EvaluateResult, StopException
from utils.file import (dump_json, dump_yaml, load_json, load_yaml, unzip_dir,
                        write_file)
from utils.leaderboard import change_product_unavailable
from utils.logger import log
from utils.service import register_sut
from utils.update_submit import change_product_available

lck = threading.Lock()

# Environment variables provided by the leaderboard (required — a missing one
# raises KeyError immediately, which is the desired fail-fast behavior).
DATASET_FILEPATH = os.environ["DATASET_FILEPATH"]
RESULT_FILEPATH = os.environ["RESULT_FILEPATH"]
DETAILED_CASES_FILEPATH = os.environ["DETAILED_CASES_FILEPATH"]
SUBMIT_CONFIG_FILEPATH = os.environ["SUBMIT_CONFIG_FILEPATH"]
BENCHMARK_NAME = os.environ["BENCHMARK_NAME"]
TEST_CONCURRENCY = int(os.getenv('TEST_CONCURRENCY', 1))
THRESHOLD_OMCER = float(os.getenv('THRESHOLD_OMCER', 0.8))

log.info(f"DATASET_FILEPATH: {DATASET_FILEPATH}")

workspace_path = "/tmp/workspace"

# Environment variable provided by kubernetes (required).
MY_POD_IP = os.environ["MY_POD_IP"]

# constants
RESOURCE_NAME = BENCHMARK_NAME

# Environment variables provided by judge_flow_config (all optional).
LANG = os.getenv("lang")
SUT_CPU = os.getenv("SUT_CPU", "2")
SUT_MEMORY = os.getenv("SUT_MEMORY", "4Gi")
SUT_VGPU = os.getenv("SUT_VGPU", "1")
SUT_VGPU_ACCELERATOR = os.getenv("SUT_VGPU_ACCELERATOR", "iluvatar-BI-V100")
RESOURCE_TYPE = os.getenv("RESOURCE_TYPE", "vgpu")
assert RESOURCE_TYPE in [
    "cpu",
    "vgpu",
], "benchmark judge_flow_config error: RESOURCE_TYPE should be cpu or vgpu"

# Unpack the dataset archive once, at import time.
unzip_dir(DATASET_FILEPATH, workspace_path)


def get_sut_url_kubernetes():
    """Register the SUT deployment on Kubernetes and return its HTTP URL.

    Reads the submit config YAML, injects the container spec, resource
    requests/limits and the iluvatar GPU node selector, registers the
    resource via ``register_sut`` and rewrites the returned ``ws://``
    scheme to ``http://``.
    """
    with open(SUBMIT_CONFIG_FILEPATH, "r") as f:
        submit_config = yaml.safe_load(f)
    assert isinstance(submit_config, dict)
    submit_config.setdefault("values", {})
    submit_config["values"]["containers"] = [
        {
            "name": "corex-container",
            "image": "harbor.4pd.io/lab-platform/inf/python:3.9",
            # Placeholder start command: keep the pod alive for one hour.
            "command": ["sleep"],
            "args": ["3600"],
        }
    ]
    submit_config["values"]["resources"] = {
        "requests": {},
        "limits": {},
    }
    limits = submit_config["values"]["resources"]["limits"]
    # Renamed from ``requests`` to avoid shadowing the HTTP library module.
    resource_requests = submit_config["values"]["resources"]["requests"]
    # iluvatar GPU resource key (the nvidia key was replaced upstream).
    vgpu_resource = {
        "iluvatar.ai/gpu": SUT_VGPU,
        # If more resources (e.g. GPU memory) are needed, add them here,
        # for example: "iluvatar.ai/gpumem": SUT_VGPU_MEM,
    }
    limits.update(vgpu_resource)
    resource_requests.update(vgpu_resource)
    # Node selector: pin the pod to the iluvatar accelerator nodes.
    submit_config["values"]["nodeSelector"] = {
        "contest.4pd.io/accelerator": "iluvatar-BI-V100"
    }
    # NOTE(review): a ``tolerations`` assignment (hosttype iluvatar / arm64 /
    # myinit / middleware) existed here but was commented out in the original
    # file; it is intentionally not applied.
    log.info(f"submit_config: {submit_config}")
    log.info(f"RESOURCE_NAME: {RESOURCE_NAME}")
    return register_sut(submit_config, RESOURCE_NAME).replace(
        "ws://", "http://"
    )


def get_sut_url():
    """Return the SUT URL; only the Kubernetes path is currently supported."""
    return get_sut_url_kubernetes()


#############################################################################


def _audio_payload(file_path):
    """Read *file_path* and wrap it in the standard base64 audio envelope.

    Shared by every ``gen_req_body`` branch that uploads audio; the encoding
    parameters (lame / 16 kHz / mono / 16-bit) are fixed by the SID API.
    """
    with open(file_path, "rb") as f:
        audio_bytes = f.read()
    return {
        "resource": {
            "encoding": "lame",
            "sample_rate": 16000,
            "channels": 1,
            "bit_depth": 16,
            "status": 3,
            "audio": str(base64.b64encode(audio_bytes), 'UTF-8')
        }
    }


def gen_req_body(apiname, APPId, file_path=None, featureId=None,
                 featureInfo=None, dstFeatureId=None):
    """Build the JSON request body for one voiceprint (SID) API call.

    :param apiname: one of createFeature / createGroup / deleteFeature /
        queryFeatureList / searchFea / searchScoreFea / updateFeature /
        deleteGroup
    :param APPId: application id placed in the request header
    :param file_path: audio file path (for the audio-carrying calls)
    :param featureId: feature id (createFeature)
    :param featureInfo: feature description (createFeature)
    :param dstFeatureId: target feature id (searchScoreFea)
    :return: request body dict
    :raises Exception: if *apiname* is not a recognized API name
    """
    header = {
        "app_id": APPId,
        "status": 3
    }
    if apiname == 'createFeature':
        body = {
            "header": header,
            "parameter": {
                "s782b4996": {
                    "func": "createFeature",
                    "groupId": "test_voiceprint_e",
                    "featureId": featureId,
                    "featureInfo": featureInfo,
                    "createFeatureRes": {
                        "encoding": "utf8",
                        "compress": "raw",
                        "format": "json"
                    }
                }
            },
            "payload": _audio_payload(file_path)
        }
    elif apiname == 'createGroup':
        body = {
            "header": header,
            "parameter": {
                "s782b4996": {
                    "func": "createGroup",
                    "groupId": "test_voiceprint_e",
                    "groupName": "vip_user",
                    "groupInfo": "store_vip_user_voiceprint",
                    "createGroupRes": {
                        "encoding": "utf8",
                        "compress": "raw",
                        "format": "json"
                    }
                }
            }
        }
    elif apiname == 'deleteFeature':
        body = {
            "header": header,
            "parameter": {
                "s782b4996": {
                    "func": "deleteFeature",
                    "groupId": "iFLYTEK_examples_groupId",
                    "featureId": "iFLYTEK_examples_featureId",
                    "deleteFeatureRes": {
                        "encoding": "utf8",
                        "compress": "raw",
                        "format": "json"
                    }
                }
            }
        }
    elif apiname == 'queryFeatureList':
        body = {
            "header": header,
            "parameter": {
                "s782b4996": {
                    "func": "queryFeatureList",
                    "groupId": "user_voiceprint_2",
                    "queryFeatureListRes": {
                        "encoding": "utf8",
                        "compress": "raw",
                        "format": "json"
                    }
                }
            }
        }
    elif apiname == 'searchFea':
        body = {
            "header": header,
            "parameter": {
                "s782b4996": {
                    "func": "searchFea",
                    "groupId": "test_voiceprint_e",
                    "topK": 1,
                    "searchFeaRes": {
                        "encoding": "utf8",
                        "compress": "raw",
                        "format": "json"
                    }
                }
            },
            "payload": _audio_payload(file_path)
        }
    elif apiname == 'searchScoreFea':
        body = {
            "header": header,
            "parameter": {
                "s782b4996": {
                    "func": "searchScoreFea",
                    "groupId": "test_voiceprint_e",
                    "dstFeatureId": dstFeatureId,
                    "searchScoreFeaRes": {
                        "encoding": "utf8",
                        "compress": "raw",
                        "format": "json"
                    }
                }
            },
            "payload": _audio_payload(file_path)
        }
    elif apiname == 'updateFeature':
        body = {
            "header": header,
            "parameter": {
                "s782b4996": {
                    "func": "updateFeature",
                    "groupId": "iFLYTEK_examples_groupId",
                    "featureId": "iFLYTEK_examples_featureId",
                    "featureInfo": "iFLYTEK_examples_featureInfo_update",
                    "updateFeatureRes": {
                        "encoding": "utf8",
                        "compress": "raw",
                        "format": "json"
                    }
                }
            },
            "payload": _audio_payload(file_path)
        }
    elif apiname == 'deleteGroup':
        body = {
            "header": header,
            "parameter": {
                "s782b4996": {
                    "func": "deleteGroup",
                    "groupId": "iFLYTEK_examples_groupId",
                    "deleteGroupRes": {
                        "encoding": "utf8",
                        "compress": "raw",
                        "format": "json"
                    }
                }
            }
        }
    else:
        raise Exception(
            "输入的apiname不在[createFeature, createGroup, deleteFeature, queryFeatureList, searchFea, searchScoreFea,updateFeature]内,请检查")
    return body


# Register the SUT at import time and remember its URL for the whole run.
log.info("开始请求获取到SUT服务URL")
sut_url = get_sut_url()
print(f"获取到的SUT_URL: {sut_url}")  # debug output
log.info(f"获取到SUT服务URL: {sut_url}")

# Global, written by the (currently disabled) base64-decoding branch of req_url.
text_decoded = None


def req_url(api_name, APPId, file_path=None, featureId=None, featureInfo=None,
            dstFeatureId=None):
    """Send one request cycle to the SUT and persist the raw JSON response.

    Health-checks the SUT, posts a small prediction request to
    ``/v1/private/s782b4996``, prints the returned metrics and dumps the
    final JSON response to RESULT_FILEPATH.

    NOTE(review): this function does not appear to be called anywhere in this
    file; its behavior has been preserved as-is.
    """
    global text_decoded
    body = gen_req_body(apiname=api_name, APPId=APPId, file_path=file_path,
                        featureId=featureId, featureInfo=featureInfo,
                        dstFeatureId=dstFeatureId)
    parsed_url = urlparse(sut_url)
    headers = {'content-type': "application/json",
               'host': parsed_url.hostname, 'appid': APPId}

    # 1. Service health check first.
    response = requests.get(f"{sut_url}/health")
    print(response.status_code, response.text)

    # The SID-style headers/body built above are then replaced by the simple
    # prediction request actually sent below (behavior kept from original).
    headers = {"Content-Type": "application/json"}
    body = {"limit": 20}  # optional: cap the number of processed images

    response = requests.post(
        f"{sut_url}/v1/private/s782b4996",
        data=json.dumps(body),
        headers=headers
    )

    if response.status_code == 200:
        result = response.json()
        print("预测评估结果:")
        print(f"准确率: {result['metrics']['accuracy']}%")
        print(f"平均召回率: {result['metrics']['average_recall']}%")
        print(f"处理图片总数: {result['metrics']['total_images']}")
    else:
        print(f"请求失败,状态码: {response.status_code}")
        print(f"错误信息: {response.text}")

    # SECURITY(review): hard-coded credential kept from the original; it is
    # currently unused (the authenticated request is disabled) but should be
    # moved to an environment variable or secret store.
    auth = ('llm', 'Rmf4#LcG(iFZrjU;2J')

    # Persist the raw response JSON (raises if the body is not valid JSON).
    result = response.json()
    with open(RESULT_FILEPATH, "w") as f:
        json.dump(result, f, indent=4, ensure_ascii=False)
    print(f"结果已成功写入 {RESULT_FILEPATH}")


# Local-test fallbacks for the output paths (env vars win when set).
submit_config_filepath = os.getenv("SUBMIT_CONFIG_FILEPATH",
                                   "./tests/resources/submit_config")
result_filepath = os.getenv("RESULT_FILEPATH", "./out/result")
bad_cases_filepath = os.getenv("BAD_CASES_FILEPATH", "./out/badcase")


def result2file(
    result: Dict[str, Any], detail_cases: List[Dict[str, Any]] = None
):
    """Write the aggregate *result* dict to ``result_filepath`` as JSON.

    :param result: metrics dict to persist (skipped when None)
    :param detail_cases: accepted for interface compatibility; the detailed
        case dump is currently disabled.
    """
    assert result_filepath is not None
    assert bad_cases_filepath is not None
    if result is not None:
        with open(result_filepath, "w") as f:
            json.dump(result, f, indent=4, ensure_ascii=False)


def test_image_prediction(sut_url, image_path):
    """POST a single image to the SUT and return ``(top_prediction, error)``.

    Exactly one of the two return slots is non-None: on success the server's
    ``top_prediction`` value, otherwise an error description string.
    """
    url = f"{sut_url}/v1/private/s782b4996"
    try:
        with open(image_path, 'rb') as f:
            files = {'image': f}
            response = requests.post(url, files=files, timeout=30)
        result = response.json()
        if result.get('status') != 'success':
            return None, f"服务端错误: {result.get('message')}"
        return result.get('top_prediction'), None
    except Exception as e:
        # Network failures, timeouts and non-JSON bodies all land here.
        return None, f"请求错误: {str(e)}"


if __name__ == '__main__':
    print(f"\n===== main开始请求接口 ===============================================")

    # 1. Service health check first.
    print(f"\n===== 服务健康检查 ===================================================")
    response = requests.get(f"{sut_url}/health")
    print(response.status_code, response.text)

    ###########################################################################
    dataset_root = "/tmp/workspace/256ObjectCategoriesNew"  # dataset root dir
    samples_per_class = 3  # images sampled per class
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')

    # Result counters.
    correct_predictions = 0
    total_samples = 0
    true_positives = 0
    false_positives = 0
    false_negatives = 0
    total_processing_time = 0.0  # total wall-clock time in seconds

    # Walk every class folder ("<index>.<name>" naming convention).
    for folder_name in os.listdir(dataset_root):
        folder_path = os.path.join(dataset_root, folder_name)
        if not os.path.isdir(folder_path):
            continue
        try:
            class_name = folder_name.split('.', 1)[1].strip().lower()
        except IndexError:
            print(f"警告:文件夹 {folder_name} 命名格式不正确,跳过该文件夹")
            continue

        # Collect all image files in the folder.
        image_files = []
        for file in os.listdir(folder_path):
            file_path = os.path.join(folder_path, file)
            if os.path.isfile(file_path) and file.lower().endswith(image_extensions):
                image_files.append(file_path)

        # Randomly sample up to samples_per_class images.
        selected_images = random.sample(
            image_files, min(samples_per_class, len(image_files))
        )

        for img_path in selected_images:
            total_samples += 1
            start_time = time.time()
            prediction, error = test_image_prediction(sut_url, img_path)
            # Per-image latency includes the network round trip.
            processing_time = time.time() - start_time
            total_processing_time += processing_time

            if error:
                print(f"处理图片 {img_path} 失败: {error}")
                # A failed sample counts as a missed positive.
                false_negatives += 1
                continue

            # NOTE(review): assumes the server always returns a dict here;
            # a None top_prediction would raise — confirm against the SUT.
            pred_class = prediction.get('class_name', '').lower()
            confidence = prediction.get('confidence', 0)
            # Correct when the true class name appears in the predicted one.
            is_correct = class_name in pred_class
            if is_correct:
                true_positives += 1
            else:
                false_positives += 1
                false_negatives += 1
            print(f"图片: {os.path.basename(img_path)} | 真实: {class_name} | 预测: {pred_class} | 置信度: {confidence:.4f} | {'正确' if is_correct else '错误'}")

    # Aggregate metrics (percentages; accuracy equals recall in this
    # single-label setup because every sample contributes to both).
    result = {
        "total_processing_time": round(total_processing_time, 6),
        "throughput": 0.0,
        "accuracy": 0.0,
        "recall": 0.0
    }
    if total_samples == 0:
        print("没有找到任何图片样本")
    accuracy = true_positives / total_samples * 100 if total_samples > 0 else 0
    recall_denominator = true_positives + false_negatives
    recall = true_positives / recall_denominator * 100 if recall_denominator > 0 else 0
    # Images per second; guard against division by ~zero elapsed time.
    throughput = total_samples / total_processing_time if total_processing_time > 1e-6 else 0

    result.update({
        "throughput": round(throughput, 6),
        "accuracy": round(accuracy, 6),
        "recall": round(recall, 6)
    })

    print("\n" + "="*50)
    print(f"总样本数: {total_samples}")
    print(f"总处理时间: {total_processing_time:.4f}秒")
    print(f"处理速度: {throughput:.2f}张/秒")
    print(f"正确预测: {true_positives}")
    print(f"错误预测: {total_samples - true_positives}")
    print(f"准确率: {accuracy:.4f} ({true_positives}/{total_samples})")
    print(f"召回率: {recall:.4f} ({true_positives}/{recall_denominator})")
    print("="*50)

    result2file(result)

    # NOTE(review): threshold checks (accuracy / latency gates calling
    # change_product_unavailable) were commented out in the original.
    exit_code = 0