diff --git a/run_callback.py b/run_callback.py index 0a64919..1bde6d2 100644 --- a/run_callback.py +++ b/run_callback.py @@ -111,46 +111,6 @@ def get_sut_url_kubernetes(): requests = submit_config["values"]["resources"]["requests"] - """ - # ########## 关键修改:替换为iluvatar GPU配置 ########## - if RESOURCE_TYPE == "vgpu": # 假设你的模型需要GPU - # 替换nvidia资源键为iluvatar.ai/gpu - vgpu_resource = { - "iluvatar.ai/gpu": SUT_VGPU, # 对应你的GPU资源键 - # 若需要其他资源(如显存),按你的K8s配置补充,例如: - # "iluvatar.ai/gpumem": SUT_VGPU_MEM, - } - limits.update(vgpu_resource) - requests.update(vgpu_resource) - # 节点选择器:替换为你的accelerator标签 - submit_config["values"]["nodeSelector"] = { - "contest.4pd.io/accelerator": "iluvatar-BI-V100" # 你的节点标签 - } - # 容忍度:替换为你的tolerations配置 - submit_config["values"]["tolerations"] = [ - { - "key": "hosttype", - "operator": "Equal", - "value": "iluvatar", - "effect": "NoSchedule", - } - ] - # ######################################### - # 禁止CPU模式下使用GPU资源(保持原逻辑) - else: - if "iluvatar.ai/gpu" in limits or "iluvatar.ai/gpu" in requests: - log.error("禁止在CPU模式下使用GPU资源") - sys.exit(1) - - - - #gpukeys = ["iluvatar.ai/gpu"] # 检查iluvatar GPU键 - #for key in gpukeys: - # if key in limits or key in requests: - # log.error("禁止使用vgpu资源") - # sys.exit(1) - - """ # 替换nvidia资源键为iluvatar.ai/gpu vgpu_resource = { @@ -165,54 +125,8 @@ def get_sut_url_kubernetes(): "contest.4pd.io/accelerator": "iluvatar-BI-V100" # 你的节点标签 } # 容忍度:替换为你的tolerations配置 - """ - submit_config["values"]["tolerations"] = [ - { - "key": "hosttype", - "operator": "Equal", - "value": "iluvatar", - "effect": "NoSchedule", - }, - { - "key": "hosttype", - "operator": "Equal", - "value": "arm64", - "effect": "NoSchedule", - }, - { - "key": "hosttype", - "operator": "Equal", - "value": "myinit", - "effect": "NoSchedule", - }, - { - "key": "hosttype", - "operator": "Equal", - "value": "middleware", - "effect": "NoSchedule", - } - - ] - """ - """ - { - "key": "node-role.kubernetes.io/master", - "operator": "Exists", - "effect": "NoSchedule", - }, - { - "key": "node.kubernetes.io/not-ready", - "operator": "Exists", - "effect": "NoExecute", - "tolerationSeconds": 300 - }, - { - "key": "node.kubernetes.io/unreachable", - "operator": "Exists", - "effect": "NoExecute", - "tolerationSeconds": 300 - } - """ + + log.info(f"submit_config: {submit_config}") @@ -629,46 +543,6 @@ if __name__ == '__main__': response = requests.get(f"{sut_url}/health") print(response.status_code, response.text) - """ - # 本地图片路径和真实标签(根据实际情况修改) - image_path = "/path/to/your/test_image.jpg" - true_label = "cat" # 图片的真实标签 - """ - - - """ - # 请求头 - headers = {"Content-Type": "application/json"} - # 请求体(可指定限制处理的图片数量) - body = {"limit": 20 } # 可选参数,限制处理的图片总数 - - # 发送POST请求 - response = requests.post( - f"{sut_url}/v1/private/s782b4996", - data=json.dumps(body), - headers=headers - ) - """ - - """ - # 读取图片文件 - with open(image_path, 'rb') as f: - files = {'image': f} - # 发送POST请求 - response = requests.post(f"{sut_url}/v1/private/s782b4996", files=files) - - - # 解析响应结果 - if response.status_code == 200: - result = response.json() - print("预测评估结果:") - print(f"准确率: {result['metrics']['accuracy']}%") - print(f"平均召回率: {result['metrics']['average_recall']}%") - print(f"处理图片总数: {result['metrics']['total_images']}") - else: - print(f"请求失败,状态码: {response.status_code}") - print(f"错误信息: {response.text}") - """ ############################################################################################### @@ -692,49 +566,7 @@ if __name__ == '__main__': cpu_false_negatives = 0 cpu_total_processing_time = 0.0 - """ - # 遍历所有类别文件夹 - for folder_name in tqdm(os.listdir(dataset_root), desc="处理类别"): - folder_path = os.path.join(dataset_root, folder_name) - - - # 提取类别名(从"序号.name"格式中提取name部分) - class_name = folder_name.split('.', 1)[1].strip().lower() - - # 获取文件夹中所有图片 - image_files = [] - for file in os.listdir(folder_path): - if file.lower().endswith(image_extensions): - image_files.append(os.path.join(folder_path, file)) - - # 随机抽取指定数量的图片(如果不足则取全部) - selected_images = random.sample( - image_files, - min(samples_per_class, len(image_files)) - ) - - # 处理选中的图片 - for img_path in selected_images: - total_count += 1 - - # 发送预测请求 - prediction, error = test_image_prediction(sut_url, img_path) - if error: - print(f"处理图片 {img_path} 失败: {error}") - continue - - # 解析预测结果 - pred_class = prediction.get('class_name', '').lower() - confidence = prediction.get('confidence', 0) - - # 判断是否预测正确(真实类别是否在预测类别中) - if class_name in pred_class: - correct_predictions += 1 - - - # 可选:打印详细结果 - print(f"图片: {os.path.basename(img_path)} | 真实: {class_name} | 预测: {pred_class} | 置信度: {confidence:.4f} | {'正确' if is_correct else '错误'}") - """ + # 遍历所有类别文件夹 for folder_name in os.listdir(dataset_root): @@ -815,25 +647,7 @@ if __name__ == '__main__': print(f"GPU预测: {gpu_pred_class} | {'正确' if gpu_is_correct else '错误'} | 耗时: {gpu_processing_time:.6f}s") print(f"CPU预测: {cpu_pred_class} | {'正确' if cpu_is_correct else '错误'} | 耗时: {cpu_processing_time:.6f}s") print("-" * 50) - - """ - # 计算整体指标(在单标签场景下,准确率=召回率) - if total_samples == 0: - overall_accuracy = 0.0 - overall_recall = 0.0 - else: - overall_accuracy = correct_predictions / total_samples - overall_recall = correct_predictions / total_samples # 整体召回率 - - # 输出统计结果 - print("\n" + "="*50) - print(f"测试总结:") - print(f"总测试样本数: {total_samples}") - print(f"正确预测样本数: {correct_predictions}") - print(f"整体准确率: {overall_accuracy:.4f} ({correct_predictions}/{total_samples})") - print(f"整体召回率: {overall_recall:.4f} ({correct_predictions}/{total_samples})") - print("="*50) - """ + # 初始化结果字典 result = { @@ -899,25 +713,7 @@ if __name__ == '__main__': log.error(f"gpu与cpu准确率差别超过3%,模型结果不正确") change_product_unavailable() - """ - if result['accuracy_1_1'] < 0.9: - log.error(f"1:1正确率未达到90%, 视为产品不可用") - change_product_unavailable() - - if result['accuracy_1_N'] < 1: - log.error(f"1:N正确率未达到100%, 视为产品不可用") - change_product_unavailable() - if result['1_1_latency'] > 0.5: - log.error(f"1:1平均latency超过0.5s, 视为产品不可用") - change_product_unavailable() - if result['1_N_latency'] > 0.5: - log.error(f"1:N平均latency超过0.5s, 视为产品不可用") - change_product_unavailable() - if result['enroll_latency'] > 1: - log.error(f"enroll(入库)平均latency超过1s, 视为产品不可用") - change_product_unavailable() - """ exit_code = 0