From 9d18371bb709e23e3077ee498c8bbfc3fb4d874b Mon Sep 17 00:00:00 2001 From: zhousha <736730048@qq.com> Date: Tue, 12 Aug 2025 18:19:12 +0800 Subject: [PATCH] push to main --- .DS_Store | Bin 6148 -> 6148 bytes Dockerfile_bi100 | 12 + model_test_caltech_3.py | 215 ----- model_test_caltech_cpu1.py | 197 ----- model_test_caltech_http_1.py | 163 ---- model_test_caltech_http_3.py | 89 --- model_test_caltech_http_cuda.py | 80 -- run_callback_cuda.py | 1193 ---------------------------- run_callback_new.py | 1296 ------------------------------- 9 files changed, 12 insertions(+), 3233 deletions(-) create mode 100644 Dockerfile_bi100 delete mode 100644 model_test_caltech_3.py delete mode 100644 model_test_caltech_cpu1.py delete mode 100644 model_test_caltech_http_1.py delete mode 100644 model_test_caltech_http_3.py delete mode 100644 model_test_caltech_http_cuda.py delete mode 100644 run_callback_cuda.py delete mode 100644 run_callback_new.py diff --git a/.DS_Store b/.DS_Store index 2a2f59598f18ea868848d35804cc0483f4b28dbf..647509b8f484da8dbf518208e44a023635eb488d 100644 GIT binary patch delta 108 zcmZoMXfc=|#>CJ*u~2NHo}wr-0|Nsi1A_nqgDyidLk2@4LlHyC#6opO=E?TVD&jm0 z#SA47L556*Jce|Jc!q3-%*~<9n^-naVD4ku%+A5j0W@xNBlCCW$^0UQj0}@)c%(PS Ih^$}+0OwE`wg3PC delta 397 zcmZoMXfc=|#>B)qu~2NHo}wr#0|Nsi1A_nqLkUANLq0=1LncG+#*fPx>p?QS42cZI z3`GpO$g)7ObcWQV^5TM|octu9svSuMIhn;J1_sv{nV4Bv+1NSQIk-7ugER8WgG&-i zN{gKmi=siiko^3dBp5p}DJ(O!JYGP=IX|x?F*7f<2&^G9B^9V7CcHi~FD1X+DZex? zr5LO^7$U>L$-x;fAW>ayY+|aTU~FJitD{hDX=I?IU}9`mTg%BIs;qAv6rY`wo0s1O zbQ};c0zC)@yigiObpsi=yl4jWV@h#yP7=_~AoaM^0hN^n7v<&T=cNNxF;47J*?4sc k%Vu^Cehy$LZ~XY3c{0C 1: - print(f"使用 {torch.cuda.device_count()} 块 GPU 进行计算") - -class COCOImageClassifier: - def __init__(self, model_path: str, local_image_paths: list): - """初始化COCO图像分类器""" - self.processor = AutoImageProcessor.from_pretrained(model_path) - self.model = AutoModelForImageClassification.from_pretrained(model_path) - - # 将模型移动到设备 - self.model = self.model.to(device) - print(f"模型是否在 GPU 上: {next(self.model.parameters()).is_cuda}") # 添加调试信息 - - # 若有多块 GPU,使用 DataParallel - if torch.cuda.device_count() > 1: - self.model = torch.nn.DataParallel(self.model) - - self.id2label = self.model.module.config.id2label if hasattr(self.model, 'module') else self.model.config.id2label - self.local_image_paths = local_image_paths - - def predict_image_path(self, image_path: str, top_k: int = 5) -> dict: - """ - 预测本地图片文件对应的图片类别 - - Args: - image_path: 本地图片文件路径 - top_k: 返回置信度最高的前k个类别 - - Returns: - 包含预测结果的字典 - """ - try: - # 打开图片 - image = Image.open(image_path).convert("RGB") - - # 预处理 - inputs = self.processor(images=image, return_tensors="pt") - - # 将输入数据移动到设备 - inputs = inputs.to(device) - - # 模型推理 - with torch.no_grad(): - outputs = self.model(**inputs) - - # 获取预测结果 - logits = outputs.logits - probs = torch.nn.functional.softmax(logits, dim=1) - top_probs, top_indices = probs.topk(top_k, dim=1) - - # 整理结果 - predictions = [] - for i in range(top_k): - class_idx = top_indices[0, i].item() - confidence = top_probs[0, i].item() - predictions.append({ - "class_id": class_idx, - "class_name": self.id2label[class_idx], - "confidence": confidence - }) - - return { - "image_path": image_path, - "predictions": predictions - } - - except Exception as e: - print(f"处理图片文件 {image_path} 时出错: {e}") - return None - - def batch_predict(self, limit: int = 20, top_k: int = 5) -> list: - """ - 批量预测本地图片 - - Args: - limit: 限制处理的图片数量 - top_k: 返回置信度最高的前k个类别 - - Returns: - 包含所有预测结果的列表 - """ - results = [] - local_image_paths = self.local_image_paths[:limit] - - print(f"开始预测 {len(local_image_paths)} 张本地图片...") - start_time = time.time() # 记录开始时间 - for image_path in tqdm(local_image_paths): - result = self.predict_image_path(image_path, top_k) - if result: - results.append(result) - end_time = time.time() # 记录结束时间 - total_time = end_time - start_time # 计算总时间 - images_per_second = len(results) / total_time # 计算每秒处理的图片数量 - print(f"模型每秒可以处理 {images_per_second:.2f} 张图片") - return results - - def save_results(self, results: list, output_file: str = "caltech_predictions.json"): - """ - 保存预测结果到JSON文件 - - Args: - results: 预测结果列表 - output_file: 输出文件名 - """ - with open(output_file, 'w', encoding='utf-8') as f: - json.dump(results, f, ensure_ascii=False, indent=2) - - print(f"结果已保存到 {output_file}") - -# 主程序 -if __name__ == "__main__": - # 替换为本地模型路径 - LOCAL_MODEL_PATH = "/home/zhoushasha/models/microsoft_beit_base_patch16_224_pt22k_ft22k" - - # 替换为Caltech 256数据集文件夹路径 - CALTECH_256_PATH = "/home/zhoushasha/models/256ObjectCategoriesNew" - - local_image_paths = [] - true_labels = {} - - # 遍历Caltech 256数据集中的每个文件夹 - for folder in os.listdir(CALTECH_256_PATH): - folder_path = os.path.join(CALTECH_256_PATH, folder) - if os.path.isdir(folder_path): - # 获取文件夹名称中的类别名称 - class_name = folder.split('.', 1)[1] - # 获取文件夹中的所有图片文件 - image_files = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.endswith(('.jpg', '.jpeg', '.png'))] - # 随机选择3张图片 - selected_images = random.sample(image_files, min(3, len(image_files))) - for image_path in selected_images: - local_image_paths.append(image_path) - true_labels[image_path] = class_name - - # 创建分类器实例 - classifier = COCOImageClassifier(LOCAL_MODEL_PATH, local_image_paths) - - # 批量预测 - results = classifier.batch_predict(limit=len(local_image_paths), top_k=3) - - # 保存结果 - classifier.save_results(results) - - # 打印简要统计 - print(f"\n处理完成: 成功预测 {len(results)} 张图片") - if results: - print("\n示例预测结果:") - sample = results[0] - print(f"图片路径: {sample['image_path']}") - for i, pred in enumerate(sample['predictions'], 1): - print(f"{i}. {pred['class_name']} (置信度: {pred['confidence']:.2%})") - - correct_count = 0 - total_count = len(results) - - # 统计每个类别的实际样本数和正确预测数 - class_actual_count = {} - class_correct_count = {} - - for prediction in results: - image_path = prediction['image_path'] - top1_prediction = max(prediction['predictions'], key=lambda x: x['confidence']) - predicted_class = top1_prediction['class_name'].lower() - true_class = true_labels.get(image_path).lower() - - # 统计每个类别的实际样本数 - if true_class not in class_actual_count: - class_actual_count[true_class] = 0 - class_actual_count[true_class] += 1 - - # 检查预测类别中的每个单词是否包含真实标签 - words = predicted_class.split() - for word in words: - if true_class in word: - correct_count += 1 - # 统计每个类别的正确预测数 - if true_class not in class_correct_count: - class_correct_count[true_class] = 0 - class_correct_count[true_class] += 1 - break - - accuracy = correct_count / total_count - print(f"\nAccuracy: {accuracy * 100:.2f}%") - - # 计算每个类别的召回率 - recall_per_class = {} - for class_name in class_actual_count: - if class_name in class_correct_count: - recall_per_class[class_name] = class_correct_count[class_name] / class_actual_count[class_name] - else: - recall_per_class[class_name] = 0 - - # 计算平均召回率 - average_recall = sum(recall_per_class.values()) / len(recall_per_class) - print(f"\nAverage Recall: {average_recall * 100:.2f}%") \ No newline at end of file diff --git a/model_test_caltech_cpu1.py b/model_test_caltech_cpu1.py deleted file mode 100644 index deed2bd..0000000 --- a/model_test_caltech_cpu1.py +++ /dev/null @@ -1,197 +0,0 @@ -import requests -import json -import torch -from PIL import Image -from io import BytesIO -from transformers import AutoImageProcessor, AutoModelForImageClassification -from tqdm import tqdm -import os -import random -import time - -# 强制使用CPU -device = torch.device("cpu") -print(f"当前使用的设备: {device}") - -class COCOImageClassifier: - def __init__(self, model_path: str, local_image_paths: list): - """初始化COCO图像分类器""" - self.processor = AutoImageProcessor.from_pretrained(model_path) - self.model = AutoModelForImageClassification.from_pretrained(model_path) - - # 将模型移动到CPU - self.model = self.model.to(device) - self.id2label = self.model.config.id2label - self.local_image_paths = local_image_paths - - def predict_image_path(self, image_path: str, top_k: int = 5) -> dict: - """ - 预测本地图片文件对应的图片类别 - - Args: - image_path: 本地图片文件路径 - top_k: 返回置信度最高的前k个类别 - - Returns: - 包含预测结果的字典 - """ - try: - # 打开图片 - image = Image.open(image_path).convert("RGB") - - # 预处理 - inputs = self.processor(images=image, return_tensors="pt") - - # 将输入数据移动到CPU - inputs = inputs.to(device) - - # 模型推理 - with torch.no_grad(): - outputs = self.model(**inputs) - - # 获取预测结果 - logits = outputs.logits - probs = torch.nn.functional.softmax(logits, dim=1) - top_probs, top_indices = probs.topk(top_k, dim=1) - - # 整理结果 - predictions = [] - for i in range(top_k): - class_idx = top_indices[0, i].item() - confidence = top_probs[0, i].item() - predictions.append({ - "class_id": class_idx, - "class_name": self.id2label[class_idx], - "confidence": confidence - }) - - return { - "image_path": image_path, - "predictions": predictions - } - - except Exception as e: - print(f"处理图片文件 {image_path} 时出错: {e}") - return None - - def batch_predict(self, limit: int = 20, top_k: int = 5) -> list: - """ - 批量预测本地图片 - - Args: - limit: 限制处理的图片数量 - top_k: 返回置信度最高的前k个类别 - - Returns: - 包含所有预测结果的列表 - """ - results = [] - local_image_paths = self.local_image_paths[:limit] - - print(f"开始预测 {len(local_image_paths)} 张本地图片...") - start_time = time.time() - for image_path in tqdm(local_image_paths): - result = self.predict_image_path(image_path, top_k) - if result: - results.append(result) - end_time = time.time() - - # 计算吞吐量 - throughput = len(results) / (end_time - start_time) - print(f"模型每秒可以处理 {throughput:.2f} 张图片") - - return results - - def save_results(self, results: list, output_file: str = "celtech_cpu_predictions.json"): - """ - 保存预测结果到JSON文件 - - Args: - results: 预测结果列表 - output_file: 输出文件名 - """ - with open(output_file, 'w', encoding='utf-8') as f: - json.dump(results, f, ensure_ascii=False, indent=2) - - print(f"结果已保存到 {output_file}") - -# 主程序 -if __name__ == "__main__": - # 替换为本地模型路径 - LOCAL_MODEL_PATH = "/home/zhoushasha/models/microsoft_beit_base_patch16_224_pt22k_ft22k" - - # 替换为Caltech 256数据集文件夹路径 New - CALTECH_256_PATH = "/home/zhoushasha/models/256ObjectCategoriesNew" - - local_image_paths = [] - true_labels = {} - - # 遍历Caltech 256数据集中的每个文件夹 - for folder in os.listdir(CALTECH_256_PATH): - folder_path = os.path.join(CALTECH_256_PATH, folder) - if os.path.isdir(folder_path): - # 获取文件夹名称中的类别名称 - class_name = folder.split('.', 1)[1] - # 获取文件夹中的所有图片文件 - image_files = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.endswith(('.jpg', '.jpeg', '.png'))] - # 随机选择3张图片 - selected_images = random.sample(image_files, min(3, len(image_files))) - for image_path in selected_images: - local_image_paths.append(image_path) - true_labels[image_path] = class_name - - # 创建分类器实例 - classifier = COCOImageClassifier(LOCAL_MODEL_PATH, local_image_paths) - - # 批量预测 - results = classifier.batch_predict(limit=len(local_image_paths), top_k=3) - - # 保存结果 - classifier.save_results(results) - - # 打印简要统计 - print(f"\n处理完成: 成功预测 {len(results)} 张图片") - if results: - print("\n示例预测结果:") - sample = results[0] - print(f"图片路径: {sample['image_path']}") - for i, pred in enumerate(sample['predictions'], 1): - print(f"{i}. {pred['class_name']} (置信度: {pred['confidence']:.2%})") - - correct_count = 0 - total_count = len(results) - class_true_positives = {} - class_false_negatives = {} - - for prediction in results: - image_path = prediction['image_path'] - top1_prediction = max(prediction['predictions'], key=lambda x: x['confidence']) - predicted_class = top1_prediction['class_name'].lower() - true_class = true_labels.get(image_path).lower() - - if true_class not in class_true_positives: - class_true_positives[true_class] = 0 - class_false_negatives[true_class] = 0 - - # 检查预测类别中的每个单词是否包含真实标签 - words = predicted_class.split() - for word in words: - if true_class in word: - correct_count += 1 - class_true_positives[true_class] += 1 - break - else: - class_false_negatives[true_class] += 1 - - accuracy = correct_count / total_count - print(f"\nAccuracy: {accuracy * 100:.2f}%") - - # 计算召回率 - total_true_positives = 0 - total_false_negatives = 0 - for class_name in class_true_positives: - total_true_positives += class_true_positives[class_name] - total_false_negatives += class_false_negatives[class_name] - - recall = total_true_positives / (total_true_positives + total_false_negatives) - print(f"Recall: {recall * 100:.2f}%") \ No newline at end of file diff --git a/model_test_caltech_http_1.py b/model_test_caltech_http_1.py deleted file mode 100644 index 12f6d3f..0000000 --- a/model_test_caltech_http_1.py +++ /dev/null @@ -1,163 +0,0 @@ -import requests -import json -import torch -from PIL import Image -from io import BytesIO -from transformers import AutoImageProcessor, AutoModelForImageClassification -from tqdm import tqdm -import os -import random -import time -from flask import Flask, request, jsonify # 引入Flask - -# 设备配置 -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -print(f"当前使用的设备: {device}") - -class COCOImageClassifier: - def __init__(self, model_path: str): - """初始化分类器(移除local_image_paths参数,改为动态接收)""" - self.processor = AutoImageProcessor.from_pretrained(model_path) - self.model = AutoModelForImageClassification.from_pretrained(model_path) - self.model = self.model.to(device) - - if torch.cuda.device_count() > 1: - print(f"使用 {torch.cuda.device_count()} 块GPU") - self.model = torch.nn.DataParallel(self.model) - - self.id2label = self.model.module.config.id2label if hasattr(self.model, 'module') else self.model.config.id2label - - def predict_image_path(self, image_path: str, top_k: int = 5) -> dict: - """预测单张图片(复用原逻辑)""" - try: - image = Image.open(image_path).convert("RGB") - inputs = self.processor(images=image, return_tensors="pt").to(device) - - with torch.no_grad(): - outputs = self.model(** inputs) - - logits = outputs.logits - probs = torch.nn.functional.softmax(logits, dim=1) - top_probs, top_indices = probs.topk(top_k, dim=1) - - predictions = [] - for i in range(top_k): - class_idx = top_indices[0, i].item() - predictions.append({ - "class_id": class_idx, - "class_name": self.id2label[class_idx], - "confidence": top_probs[0, i].item() - }) - - return { - "image_path": image_path, - "predictions": predictions - } - except Exception as e: - print(f"处理图片 {image_path} 出错: {e}") - return None - - def batch_predict_and_evaluate(self, image_paths: list, true_labels: dict, top_k: int = 3) -> dict: - """批量预测并计算准确率、召回率""" - results = [] - start_time = time.time() - - for image_path in tqdm(image_paths): - result = self.predict_image_path(image_path, top_k) - if result: - results.append(result) - - end_time = time.time() - total_time = end_time - start_time - images_per_second = len(results) / total_time if total_time > 0 else 0 - - # 计算准确率和召回率(复用原逻辑) - correct_count = 0 - total_count = len(results) - class_actual_count = {} - class_correct_count = {} - - for prediction in results: - image_path = prediction['image_path'] - top1_prediction = max(prediction['predictions'], key=lambda x: x['confidence']) - predicted_class = top1_prediction['class_name'].lower() - true_class = true_labels.get(image_path, "").lower() - - # 统计每个类别的实际样本数 - class_actual_count[true_class] = class_actual_count.get(true_class, 0) + 1 - - # 检查预测是否正确 - words = predicted_class.split() - for word in words: - if true_class in word: - correct_count += 1 - class_correct_count[true_class] = class_correct_count.get(true_class, 0) + 1 - break - - # 计算指标 - accuracy = correct_count / total_count if total_count > 0 else 0 - recall_per_class = {} - for class_name in class_actual_count: - recall_per_class[class_name] = class_correct_count.get(class_name, 0) / class_actual_count[class_name] - - average_recall = sum(recall_per_class.values()) / len(recall_per_class) if recall_per_class else 0 - - # 返回包含指标的结果 - return { - "status": "success", - "metrics": { - "accuracy": round(accuracy * 100, 2), # 百分比 - "average_recall": round(average_recall * 100, 2), # 百分比 - "total_images": total_count, - "correct_predictions": correct_count, - "speed_images_per_second": round(images_per_second, 2) - }, - "sample_predictions": results[:3] # 示例预测结果(可选) - } - -# 初始化Flask服务 -app = Flask(__name__) -MODEL_PATH = os.environ.get("MODEL_PATH", "/model") # 容器内模型路径 -DATASET_PATH = os.environ.get("DATASET_PATH", "/app/dataset") # 容器内数据集路径 -classifier = COCOImageClassifier(MODEL_PATH) - -@app.route('/v1/private/s782b4996', methods=['POST']) -def evaluate(): - """接收请求并返回评估结果(准确率、召回率等)""" - try: - # 解析请求参数(可选:允许动态指定limit等参数) - data = request.get_json() - limit = data.get("limit", 20) # 限制处理的图片数量 - - # 加载数据集(容器内路径) - local_image_paths = [] - true_labels = {} - for folder in os.listdir(DATASET_PATH): - folder_path = os.path.join(DATASET_PATH, folder) - if os.path.isdir(folder_path): - class_name = folder.split('.', 1)[1] - image_files = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.endswith(('.jpg', '.jpeg', '.png'))] - selected_images = random.sample(image_files, min(3, len(image_files))) - for image_path in selected_images: - local_image_paths.append(image_path) - true_labels[image_path] = class_name - - # 限制处理数量 - local_image_paths = local_image_paths[:limit] - - # 执行预测和评估 - result = classifier.batch_predict_and_evaluate(local_image_paths, true_labels, top_k=3) - return jsonify(result) - - except Exception as e: - return jsonify({ - "status": "error", - "message": str(e) - }), 500 - -@app.route('/health', methods=['GET']) -def health_check(): - return jsonify({"status": "healthy", "device": str(device)}), 200 - -if __name__ == "__main__": - app.run(host='0.0.0.0', port=8000, debug=False) \ No newline at end of file diff --git a/model_test_caltech_http_3.py b/model_test_caltech_http_3.py deleted file mode 100644 index 803e884..0000000 --- a/model_test_caltech_http_3.py +++ /dev/null @@ -1,89 +0,0 @@ -import torch -from PIL import Image -from transformers import AutoImageProcessor, AutoModelForImageClassification -import os -from flask import Flask, request, jsonify -from io import BytesIO - -# 设备配置 -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -print(f"当前使用的设备: {device}") - -class ImageClassifier: - def __init__(self, model_path: str): - # 获取模型路径下的第一个子目录(假设模型文件存放在这里) - subdirs = [d for d in os.listdir(model_path) if os.path.isdir(os.path.join(model_path, d))] - if not subdirs: - raise ValueError(f"在 {model_path} 下未找到任何子目录,无法加载模型") - - # 实际的模型文件路径 - actual_model_path = os.path.join(model_path, subdirs[0]) - print(f"加载模型从: {actual_model_path}") - - self.processor = AutoImageProcessor.from_pretrained(actual_model_path) - self.model = AutoModelForImageClassification.from_pretrained(actual_model_path) - self.model = self.model.to(device) - - if torch.cuda.device_count() > 1: - print(f"使用 {torch.cuda.device_count()} 块GPU") - self.model = torch.nn.DataParallel(self.model) - - self.id2label = self.model.module.config.id2label if hasattr(self.model, 'module') else self.model.config.id2label - - def predict_single_image(self, image) -> dict: - """预测单张图片,返回置信度最高的结果""" - try: - # 处理图片 - inputs = self.processor(images=image, return_tensors="pt").to(device) - - with torch.no_grad(): - outputs = self.model(** inputs) - - logits = outputs.logits - probs = torch.nn.functional.softmax(logits, dim=1) - # 获取置信度最高的预测结果 - max_prob, max_idx = probs.max(dim=1) - class_idx = max_idx.item() - - return { - "status": "success", - "top_prediction": { - "class_id": class_idx, - "class_name": self.id2label[class_idx], - "confidence": max_prob.item() - } - } - except Exception as e: - return { - "status": "error", - "message": str(e) - } - -# 初始化服务 -app = Flask(__name__) -MODEL_PATH = os.environ.get("MODEL_PATH", "/model") # 模型根路径(环境变量或默认路径) -classifier = ImageClassifier(MODEL_PATH) - -@app.route('/v1/private/s782b4996', methods=['POST']) -def predict_single(): - """接收单张图片并返回最高置信度预测结果""" - # 检查是否有图片上传 - if 'image' not in request.files: - return jsonify({"status": "error", "message": "请求中未包含图片"}), 400 - - image_file = request.files['image'] - try: - # 读取图片 - image = Image.open(BytesIO(image_file.read())).convert("RGB") - # 预测 - result = classifier.predict_single_image(image) - return jsonify(result) - except Exception as e: - return jsonify({"status": "error", "message": str(e)}), 500 - -@app.route('/health', methods=['GET']) -def health_check(): - return jsonify({"status": "healthy", "device": str(device)}), 200 - -if __name__ == "__main__": - app.run(host='0.0.0.0', port=8000, debug=False) \ No newline at end of file diff --git a/model_test_caltech_http_cuda.py b/model_test_caltech_http_cuda.py deleted file mode 100644 index 2fec98c..0000000 --- a/model_test_caltech_http_cuda.py +++ /dev/null @@ -1,80 +0,0 @@ -import torch -from PIL import Image -from transformers import AutoImageProcessor, AutoModelForImageClassification -import os -from flask import Flask, request, jsonify -from io import BytesIO - -# 设备配置 -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -print(f"当前使用的设备: {device}") - -class ImageClassifier: - def __init__(self, model_path: str): - self.processor = AutoImageProcessor.from_pretrained(model_path) - self.model = AutoModelForImageClassification.from_pretrained(model_path) - self.model = self.model.to(device) - - if torch.cuda.device_count() > 1: - print(f"使用 {torch.cuda.device_count()} 块GPU") - self.model = torch.nn.DataParallel(self.model) - - self.id2label = self.model.module.config.id2label if hasattr(self.model, 'module') else self.model.config.id2label - - def predict_single_image(self, image) -> dict: - """预测单张图片,返回置信度最高的结果""" - try: - # 处理图片 - inputs = self.processor(images=image, return_tensors="pt").to(device) - - with torch.no_grad(): - outputs = self.model(**inputs) - - logits = outputs.logits - probs = torch.nn.functional.softmax(logits, dim=1) - # 获取置信度最高的预测结果 - max_prob, max_idx = probs.max(dim=1) - class_idx = max_idx.item() - - return { - "status": "success", - "top_prediction": { - "class_id": class_idx, - "class_name": self.id2label[class_idx], - "confidence": max_prob.item() - } - } - except Exception as e: - return { - "status": "error", - "message": str(e) - } - -# 初始化服务 -app = Flask(__name__) -MODEL_PATH = os.environ.get("MODEL_PATH", "/model") # 模型路径(环境变量或默认路径) -classifier = ImageClassifier(MODEL_PATH) - -@app.route('/v1/private/s782b4996', methods=['POST']) -def predict_single(): - """接收单张图片并返回最高置信度预测结果""" - # 检查是否有图片上传 - if 'image' not in request.files: - return jsonify({"status": "error", "message": "请求中未包含图片"}), 400 - - image_file = request.files['image'] - try: - # 读取图片 - image = Image.open(BytesIO(image_file.read())).convert("RGB") - # 预测 - result = classifier.predict_single_image(image) - return jsonify(result) - except Exception as e: - return jsonify({"status": "error", "message": str(e)}), 500 - -@app.route('/health', methods=['GET']) -def health_check(): - return jsonify({"status": "healthy", "device": str(device)}), 200 - -if __name__ == "__main__": - app.run(host='0.0.0.0', port=80, debug=False) \ No newline at end of file diff --git a/run_callback_cuda.py b/run_callback_cuda.py deleted file mode 100644 index a5dc1dc..0000000 --- a/run_callback_cuda.py +++ /dev/null @@ -1,1193 +0,0 @@ -import json -import os -import sys -import time -import tempfile -import zipfile -import threading -from collections import defaultdict -from typing import Dict, List - -import yaml -from pydantic import ValidationError - -from schemas.dataset import QueryData -from utils.client_callback import ClientCallback, EvaluateResult, StopException -from utils.logger import log -from utils.service import register_sut -from utils.update_submit import change_product_available -from utils.file import dump_json, load_yaml, unzip_dir, load_json, write_file, dump_yaml -from utils.leaderboard import change_product_unavailable - - -lck = threading.Lock() - -# Environment variables by leaderboard -DATASET_FILEPATH = os.environ["DATASET_FILEPATH"] -RESULT_FILEPATH = os.environ["RESULT_FILEPATH"] - -DETAILED_CASES_FILEPATH = os.environ["DETAILED_CASES_FILEPATH"] -SUBMIT_CONFIG_FILEPATH = os.environ["SUBMIT_CONFIG_FILEPATH"] -BENCHMARK_NAME = os.environ["BENCHMARK_NAME"] -TEST_CONCURRENCY = int(os.getenv('TEST_CONCURRENCY', 1)) -THRESHOLD_OMCER = float(os.getenv('THRESHOLD_OMCER', 0.8)) - -log.info(f"DATASET_FILEPATH: {DATASET_FILEPATH}") -workspace_path = "/tmp/workspace" - - -# Environment variables by kubernetes -MY_POD_IP = os.environ["MY_POD_IP"] - -# constants -RESOURCE_NAME = BENCHMARK_NAME - -# Environment variables by judge_flow_config -LANG = os.getenv("lang") -SUT_CPU = os.getenv("SUT_CPU", "2") -SUT_MEMORY = os.getenv("SUT_MEMORY", "4Gi") -SUT_VGPU = os.getenv("SUT_VGPU", "1") -#SUT_VGPU_MEM = os.getenv("SUT_VGPU_MEM", str(1843 * int(SUT_VGPU))) -#SUT_VGPU_CORES = os.getenv("SUT_VGPU_CORES", str(8 * int(SUT_VGPU))) -SUT_VGPU_ACCELERATOR = os.getenv("SUT_VGPU_ACCELERATOR", "iluvatar-BI-V100") -RESOURCE_TYPE = os.getenv("RESOURCE_TYPE", "vgpu") -assert RESOURCE_TYPE in [ - "cpu", - "vgpu", -], "benchmark judge_flow_config error: RESOURCE_TYPE should be cpu or vgpu" - - -unzip_dir(DATASET_FILEPATH, workspace_path) - -def get_sut_url_kubernetes(): - with open(SUBMIT_CONFIG_FILEPATH, "r") as f: - submit_config = yaml.safe_load(f) - assert isinstance(submit_config, dict) - - submit_config.setdefault("values", {}) - - submit_config["values"]["containers"] = [ - { - "name": "corex-container", - "image": "harbor.4pd.io/lab-platform/inf/python:3.9", #镜像 - "command": ["sleep"], # 替换为你的模型启动命令,使用python解释器 - "args": ["3600"], # 替换为你的模型参数,运行我的推理脚本 - - # 添加存储卷挂载 - #"volumeMounts": [ - # { - # "name": "model-volume", - # "mountPath": "/model" # 挂载到/model目录 - # } - #] - } - ] - - """ - # 添加存储卷配置 - submit_config["values"]["volumes"] = [ - { - "name": "model-volume", - "persistentVolumeClaim": { - "claimName": "sid-model-pvc" # 使用已有的PVC - } - } - ] - """ - - """ - # Inject specified cpu and memory - resource = { - "cpu": SUT_CPU, - "memory": SUT_MEMORY, - } - """ - submit_config["values"]["resources"] = { - "requests":{}, - "limits": {}, - } - - limits = submit_config["values"]["resources"]["limits"] - requests = submit_config["values"]["resources"]["requests"] - - - """ - # ########## 关键修改:替换为iluvatar GPU配置 ########## - if RESOURCE_TYPE == "vgpu": # 假设你的模型需要GPU - # 替换nvidia资源键为iluvatar.ai/gpu - vgpu_resource = { - "iluvatar.ai/gpu": SUT_VGPU, # 对应你的GPU资源键 - # 若需要其他资源(如显存),按你的K8s配置补充,例如: - # "iluvatar.ai/gpumem": SUT_VGPU_MEM, - } - limits.update(vgpu_resource) - requests.update(vgpu_resource) - # 节点选择器:替换为你的accelerator标签 - submit_config["values"]["nodeSelector"] = { - "contest.4pd.io/accelerator": "iluvatar-BI-V100" # 你的节点标签 - } - # 容忍度:替换为你的tolerations配置 - submit_config["values"]["tolerations"] = [ - { - "key": "hosttype", - "operator": "Equal", - "value": "iluvatar", - "effect": "NoSchedule", - } - ] - # ######################################### - # 禁止CPU模式下使用GPU资源(保持原逻辑) - else: - if "iluvatar.ai/gpu" in limits or "iluvatar.ai/gpu" in requests: - log.error("禁止在CPU模式下使用GPU资源") - sys.exit(1) - - - - #gpukeys = ["iluvatar.ai/gpu"] # 检查iluvatar GPU键 - #for key in gpukeys: - # if key in limits or key in requests: - # log.error("禁止使用vgpu资源") - # sys.exit(1) - - """ - - # 替换nvidia资源键为iluvatar.ai/gpu - vgpu_resource = { - "iluvatar.ai/gpu": SUT_VGPU, # 对应你的GPU资源键 - # 若需要其他资源(如显存),按你的K8s配置补充,例如: - # "iluvatar.ai/gpumem": SUT_VGPU_MEM, - } - limits.update(vgpu_resource) - requests.update(vgpu_resource) - # 节点选择器:替换为你的accelerator标签 - submit_config["values"]["nodeSelector"] = { - "contest.4pd.io/accelerator": "iluvatar-BI-V100" # 你的节点标签 - } - # 容忍度:替换为你的tolerations配置 - """ - submit_config["values"]["tolerations"] = [ - { - "key": "hosttype", - "operator": "Equal", - "value": "iluvatar", - "effect": "NoSchedule", - }, - { - "key": "hosttype", - "operator": "Equal", - "value": "arm64", - "effect": "NoSchedule", - }, - { - "key": "hosttype", - "operator": "Equal", - "value": "myinit", - "effect": "NoSchedule", - }, - { - "key": "hosttype", - "operator": "Equal", - "value": "middleware", - "effect": "NoSchedule", - } - - ] - """ - """ - { - "key": "node-role.kubernetes.io/master", - "operator": "Exists", - "effect": "NoSchedule", - }, - { - "key": "node.kubernetes.io/not-ready", - "operator": "Exists", - "effect": "NoExecute", - "tolerationSeconds": 300 - }, - { - "key": "node.kubernetes.io/unreachable", - "operator": "Exists", - "effect": "NoExecute", - "tolerationSeconds": 300 - } - """ - - - log.info(f"submit_config: {submit_config}") - log.info(f"RESOURCE_NAME: {RESOURCE_NAME}") - - return register_sut(submit_config, RESOURCE_NAME).replace( - "ws://", "http://" - ) - - -def get_sut_url(): - return get_sut_url_kubernetes() - -#SUT_URL = get_sut_url() -#os.environ["SUT_URL"] = SUT_URL - -""" -def load_dataset( - dataset_filepath: str, -) -> Dict[str, List[QueryData]]: - dataset_path = tempfile.mkdtemp() - - with zipfile.ZipFile(dataset_filepath) as zf: - zf.extractall(dataset_path) - - basename = os.path.basename(dataset_filepath) - datayaml = os.path.join(dataset_path, "data.yaml") - if not os.path.exists(datayaml): - sub_dataset_paths = os.listdir(dataset_path) - dataset = {} - for sub_dataset_path in sub_dataset_paths: - sub_dataset = load_dataset( - os.path.join(dataset_path, sub_dataset_path) - ) - for k, v in sub_dataset.items(): - k = os.path.join(basename, k) - dataset[k] = v - return dataset - - with open(datayaml, "r") as f: - data = yaml.safe_load(f) - assert isinstance(data, dict) - - lang = LANG - data_lang = data.get("global", {}).get("lang") - if lang is None and data_lang is not None: - if data_lang is not None: - # 使用配置中的语言类型 - lang = data_lang - if lang is None and basename.startswith("asr.") and len(basename) == 4 + 2: - # 数据集名称为asr.en 可以认为语言为en - lang = basename[4:] - if lang is None: - log.error( - "数据集错误 通过data.yaml中的 global.lang 或 数据集名称 asr.xx 指定语言类型" - ) - sys.exit(1) - - query_data = data.get("query_data", []) - audio_size_map = {} - for query in query_data: - query["lang"] = lang - query["file"] = os.path.join(dataset_path, query["file"]) - audio_size_map[query["file"]] = os.path.getsize(query["file"]) - # 根据音频大小排序 - query_data = sorted( - query_data, key=lambda x: audio_size_map[x["file"]], reverse=True - ) - valid_query_data = [] - for i_query_data in query_data: - try: - valid_query_data.append(QueryData.model_validate(i_query_data)) - except ValidationError: - log.error("数据集错误 数据中query_data格式错误") - sys.exit(1) - - return { - basename: valid_query_data, - } - - -def merge_query_data(dataset: Dict[str, List[QueryData]]) -> List[QueryData]: - query_datas = [] - for query_data in dataset.values(): - query_datas.extend(query_data) - return query_datas - - -def run_one_predict( - client: ClientCallback, query_data: QueryData, task_id: str -) -> EvaluateResult: - try: - client.predict(None, query_data.file, query_data.duration, task_id) - except StopException: - sys.exit(1) - - client.finished.wait() - - if client.error is not None: - sys.exit(1) - - client.app_on = False - - try: - with lck: - ret = client.evaluate(query_data) - return ret - except StopException: - sys.exit(1) -""" - -""" -def predict_task( - client: ClientCallback, task_id: int, query_data: QueryData, test_results: list -): - log.info(f"Task-{task_id}开始评测") - test_results[task_id] = run_one_predict(client, query_data, str(task_id)) - - -def merge_concurrent_result(evaluate_results: List[EvaluateResult]) -> Dict: - cer = 0.0 - align_start = {} - align_end = {} - first_word_distance_sum = 0.0 - last_word_distance_sum = 0.0 - rtf = 0.0 - first_receive_delay: float = 0.0 - query_count: int = 0 - voice_count: int = 0 - pred_punctuation_num: int = 0 - label_punctuation_num: int = 0 - pred_sentence_punctuation_num: int = 0 - label_setence_punctuation_num: int = 0 - - for evalute_result in evaluate_results: - cer += evalute_result.cer - for k, v in evalute_result.align_start.items(): - align_start.setdefault(k, 0) - align_start[k] += v - for k, v in evalute_result.align_end.items(): - align_end.setdefault(k, 0) - align_end[k] += v - first_word_distance_sum += evalute_result.first_word_distance_sum - last_word_distance_sum += evalute_result.last_word_distance_sum - rtf += evalute_result.rtf - first_receive_delay += evalute_result.first_receive_delay - query_count += evalute_result.query_count - voice_count += evalute_result.voice_count - pred_punctuation_num += evalute_result.pred_punctuation_num - label_punctuation_num += evalute_result.label_punctuation_num - pred_sentence_punctuation_num += ( - evalute_result.pred_sentence_punctuation_num - ) - label_setence_punctuation_num += ( - evalute_result.label_setence_punctuation_num - ) - lens = len(evaluate_results) - cer /= lens - for k, v in align_start.items(): - align_start[k] /= voice_count - for k, v in align_end.items(): - align_end[k] /= voice_count - first_word_distance = first_word_distance_sum / voice_count - last_word_distance = last_word_distance_sum / voice_count - rtf /= lens - first_receive_delay /= lens - json_result = { - "one_minus_cer": 1 - cer, - "first_word_distance_mean": first_word_distance, - "last_word_distance_mean": last_word_distance, - "query_count": query_count // lens, - "voice_count": voice_count // lens, - "rtf": rtf, - "first_receive_delay": first_receive_delay, - "punctuation_ratio": ( - pred_punctuation_num / label_punctuation_num - if label_punctuation_num > 0 - else 1.0 - ), - "sentence_punctuation_ratio": ( - pred_sentence_punctuation_num / label_setence_punctuation_num - if label_setence_punctuation_num > 0 - else 1.0 - ), - } - for k, v in align_start.items(): - json_result["start_word_%dms_ratio" % k] = v - for k, v in align_end.items(): - json_result["end_word_%dms_ratio" % k] = v - - return json_result - - -def merge_result(result: Dict[str, List[EvaluateResult]]) -> Dict: - json_result = {} - for lang, evaluate_results in result.items(): - if len(evaluate_results) == 0: - continue - cer = 0.0 - align_start = {} - align_end = {} - first_word_distance_sum = 0.0 - last_word_distance_sum = 0.0 - rtf = 0.0 - first_receive_delay: float = 0.0 - query_count: int = 0 - voice_count: int = 0 - pred_punctuation_num: int = 0 - label_punctuation_num: int = 0 - pred_sentence_punctuation_num: int = 0 - label_setence_punctuation_num: int = 0 - for evalute_result in evaluate_results: - cer += evalute_result.cer - for k, v in evalute_result.align_start.items(): - align_start.setdefault(k, 0) - align_start[k] += v - for k, v in evalute_result.align_end.items(): - align_end.setdefault(k, 0) - align_end[k] += v - first_word_distance_sum += evalute_result.first_word_distance_sum - last_word_distance_sum += evalute_result.last_word_distance_sum - rtf += evalute_result.rtf - first_receive_delay += evalute_result.first_receive_delay - query_count += evalute_result.query_count - voice_count += evalute_result.voice_count - pred_punctuation_num += evalute_result.pred_punctuation_num - label_punctuation_num += evalute_result.label_punctuation_num - pred_sentence_punctuation_num += ( - evalute_result.pred_sentence_punctuation_num - ) - label_setence_punctuation_num += ( - evalute_result.label_setence_punctuation_num - ) - lens = len(evaluate_results) - cer /= lens - for k, v in align_start.items(): - align_start[k] /= voice_count - for k, v in align_end.items(): - align_end[k] /= voice_count - first_word_distance = first_word_distance_sum / voice_count - last_word_distance = last_word_distance_sum / voice_count - rtf /= lens - first_receive_delay /= lens - lang_result = { - "one_minus_cer": 1 - cer, - "first_word_distance_mean": first_word_distance, - "last_word_distance_mean": last_word_distance, - "query_count": 1, - "voice_count": voice_count, - "rtf": rtf, - "first_receive_delay": first_receive_delay, - "punctuation_ratio": ( - pred_punctuation_num / label_punctuation_num - if label_punctuation_num > 0 - else 1.0 - ), - "sentence_punctuation_ratio": ( - pred_sentence_punctuation_num / label_setence_punctuation_num - if label_setence_punctuation_num > 0 - else 1.0 - ), - } - for k, v in align_start.items(): - lang_result["start_word_%dms_ratio" % k] = v - for k, v in align_end.items(): - lang_result["end_word_%dms_ratio" % k] = v - if lang == "": - json_result.update(lang_result) - else: - json_result[lang] = lang_result - return json_result -""" - -""" -def main(): - log.info(f'{TEST_CONCURRENCY=}, {THRESHOLD_OMCER=}') - dataset = load_dataset(DATASET_FILEPATH) - query_datas = merge_query_data(dataset) - - #获取 ASR 服务 URL(通常从 Kubernetes 配置) - sut_url = get_sut_url() - - #创建多个客户端实例(每个客户端监听不同端口,如 80、81、82...) - port_base = 80 - clients = [ClientCallback(sut_url, port_base + i) for i in range(TEST_CONCURRENCY)] - - #准备测试数据与线程 - detail_cases = [] - # we use the same test data for all requests - query_data = query_datas[0] - - test_results = [None] * len(clients) - test_threads = [threading.Thread(target=predict_task, args=(client, task_id, query_data, test_results)) - for task_id, client in enumerate(clients)] - - #启动并发测试,启动线程并间隔10秒,设置超时时间为1小时 - for t in test_threads: - t.start() - time.sleep(10) - [t.join(timeout=3600) for t in test_threads] - - #合并结果与评估 - final_result = merge_concurrent_result(test_results) - product_avaiable = all([c.product_avaiable for c in clients]) - - final_result['concurrent_req'] = TEST_CONCURRENCY - if final_result['one_minus_cer'] < THRESHOLD_OMCER: - product_avaiable = False - - if not product_avaiable: - final_result['success'] = False - change_product_available() - else: - final_result['success'] = True - - #保存结果, - log.info( - "指标结果为: %s", json.dumps(final_result, indent=2, ensure_ascii=False) - ) - - time.sleep(120) - #打印并保存最终结果到文件 - with open(RESULT_FILEPATH, "w") as f: - json.dump(final_result, f, indent=2, ensure_ascii=False) - #保存详细测试用例结果 - with open(DETAILED_CASES_FILEPATH, "w") as f: - json.dump(detail_cases, f, indent=2, ensure_ascii=False) -""" - -############################################################################# - -import requests -import base64 - -def gen_req_body(apiname, APPId, file_path=None, featureId=None, featureInfo=None, dstFeatureId=None): - """ - 生成请求的body - :param apiname - :param APPId: Appid - :param file_name: 文件路径 - :return: - """ - if apiname == 'createFeature': - - with open(file_path, "rb") as f: - audioBytes = f.read() - body = { - "header": { - "app_id": APPId, - "status": 3 - }, - "parameter": { - "s782b4996": { - "func": "createFeature", - "groupId": "test_voiceprint_e", - "featureId": featureId, - "featureInfo": featureInfo, - "createFeatureRes": { - "encoding": "utf8", - "compress": "raw", - "format": "json" - } - } - }, - "payload": { - "resource": { - "encoding": "lame", - "sample_rate": 16000, - "channels": 1, - "bit_depth": 16, - "status": 3, - "audio": str(base64.b64encode(audioBytes), 'UTF-8') - } - } - } - elif apiname == 'createGroup': - - body = { - "header": { - "app_id": APPId, - "status": 3 - }, - "parameter": { - "s782b4996": { - "func": "createGroup", - "groupId": "test_voiceprint_e", - "groupName": "vip_user", - "groupInfo": "store_vip_user_voiceprint", - "createGroupRes": { - "encoding": "utf8", - "compress": "raw", - "format": "json" - } - } - } - } - elif apiname == 'deleteFeature': - - body = { - "header": { - "app_id": APPId, - "status": 3 - - }, - "parameter": { - "s782b4996": { - "func": "deleteFeature", - "groupId": "iFLYTEK_examples_groupId", - "featureId": "iFLYTEK_examples_featureId", - "deleteFeatureRes": { - "encoding": "utf8", - "compress": "raw", - "format": "json" - } - } - } - } - elif apiname == 'queryFeatureList': - - body = { - "header": { - "app_id": APPId, - "status": 3 - }, - "parameter": { - "s782b4996": { - "func": "queryFeatureList", - "groupId": "user_voiceprint_2", - "queryFeatureListRes": { - "encoding": "utf8", - "compress": "raw", - "format": "json" - } - } - } - } - elif apiname == 'searchFea': - - with open(file_path, "rb") as f: - audioBytes = f.read() - body = { - "header": { - "app_id": APPId, - "status": 3 - }, - "parameter": { - "s782b4996": { - "func": "searchFea", - "groupId": "test_voiceprint_e", - "topK": 1, - "searchFeaRes": { - "encoding": "utf8", - "compress": "raw", - "format": "json" - } - } - }, - "payload": { - "resource": { - "encoding": "lame", - "sample_rate": 16000, - "channels": 1, - "bit_depth": 16, - "status": 3, - "audio": str(base64.b64encode(audioBytes), 'UTF-8') - } - } - } - elif apiname == 'searchScoreFea': - - with open(file_path, "rb") as f: - audioBytes = f.read() - body = { - "header": { - "app_id": APPId, - "status": 3 - }, - "parameter": { - "s782b4996": { - "func": "searchScoreFea", - "groupId": "test_voiceprint_e", - "dstFeatureId": dstFeatureId, - "searchScoreFeaRes": { - "encoding": "utf8", - "compress": "raw", - "format": "json" - } - } - }, - "payload": { - "resource": { - "encoding": "lame", - "sample_rate": 16000, - "channels": 1, - "bit_depth": 16, - "status": 3, - "audio": str(base64.b64encode(audioBytes), 'UTF-8') - } - } - } - elif apiname == 'updateFeature': - - with open(file_path, "rb") as f: - audioBytes = f.read() - body = { - "header": { - "app_id": APPId, - "status": 3 - }, - "parameter": { - "s782b4996": { - "func": "updateFeature", - "groupId": "iFLYTEK_examples_groupId", - "featureId": "iFLYTEK_examples_featureId", - "featureInfo": "iFLYTEK_examples_featureInfo_update", - "updateFeatureRes": { - "encoding": "utf8", - "compress": "raw", - "format": "json" - } - } - }, - "payload": { - "resource": { - "encoding": "lame", - "sample_rate": 16000, - "channels": 1, - "bit_depth": 16, - "status": 3, - "audio": str(base64.b64encode(audioBytes), 'UTF-8') - } - } - } - elif apiname == 'deleteGroup': - body = { - "header": { - "app_id": APPId, - "status": 3 - }, - "parameter": { - "s782b4996": { - "func": "deleteGroup", - "groupId": "iFLYTEK_examples_groupId", - "deleteGroupRes": { - "encoding": "utf8", - "compress": "raw", - "format": "json" - } - } - } - } - else: - raise Exception( - "输入的apiname不在[createFeature, createGroup, deleteFeature, queryFeatureList, searchFea, searchScoreFea,updateFeature]内,请检查") - return body - - - -log.info(f"开始请求获取到SUT服务URL") -# 获取SUT服务URL -sut_url = get_sut_url() -print(f"获取到的SUT_URL: {sut_url}") # 调试输出 -log.info(f"获取到SUT服务URL: {sut_url}") - -from urllib.parse import urlparse - -# 全局变量 -text_decoded = None - -###################################新增新增################################ -def req_url(api_name, APPId, file_path=None, featureId=None, featureInfo=None, dstFeatureId=None): - """ - 开始请求 - :param APPId: APPID - :param file_path: body里的文件路径 - :return: - """ - - global text_decoded - - body = gen_req_body(apiname=api_name, APPId=APPId, file_path=file_path, featureId=featureId, featureInfo=featureInfo, dstFeatureId=dstFeatureId) - #request_url = 'https://ai-cloud.4paradigm.com:9443/sid/v1/private/s782b4996' - - #request_url = 'https://sut:80/sid/v1/private/s782b4996' - - #headers = {'content-type': "application/json", 'host': 'ai-cloud.4paradigm.com', 'appid': APPId} - - parsed_url = urlparse(sut_url) - headers = {'content-type': "application/json", 'host': parsed_url.hostname, 'appid': APPId} - - # 1. 首先测试服务健康检查 - response = requests.get(f"{sut_url}/health") - print(response.status_code, response.text) - - - # 请求头 - headers = {"Content-Type": "application/json"} - # 请求体(可指定限制处理的图片数量) - body = {"limit": 20 } # 可选参数,限制处理的图片总数 - - # 发送POST请求 - response = requests.post( - f"{sut_url}/v1/private/s782b4996", - data=json.dumps(body), - headers=headers - ) - - # 解析响应结果 - if response.status_code == 200: - result = response.json() - print("预测评估结果:") - print(f"准确率: {result['metrics']['accuracy']}%") - print(f"平均召回率: {result['metrics']['average_recall']}%") - print(f"处理图片总数: {result['metrics']['total_images']}") - else: - print(f"请求失败,状态码: {response.status_code}") - print(f"错误信息: {response.text}") - - - - - # 添加基本认证信息 - auth = ('llm', 'Rmf4#LcG(iFZrjU;2J') - #response = requests.post(request_url, data=json.dumps(body), headers=headers, auth=auth) - - #response = requests.post(sut_url + "/predict", data=json.dumps(body), headers=headers, auth=auth) - #response = requests.post(f"{sut_url}/sid/v1/private/s782b4996", data=json.dumps(body), headers=headers, auth=auth) - """ - response = requests.post(f"{sut_url}/v1/private/s782b4996", data=json.dumps(body), headers=headers) - """ - - - - - #print("HTTP状态码:", response.status_code) - #print("原始响应内容:", response.text) # 先打印原始内容 - #print(f"请求URL: {sut_url + '/v1/private/s782b4996'}") - #print(f"请求headers: {headers}") - #print(f"请求body: {body}") - - - - #tempResult = json.loads(response.content.decode('utf-8')) - #print(tempResult) - - """ - # 对text字段进行Base64解码 - if 'payload' in tempResult and 'updateFeatureRes' in tempResult['payload']: - text_encoded = tempResult['payload']['updateFeatureRes']['text'] - text_decoded = base64.b64decode(text_encoded).decode('utf-8') - print(f"Base64解码后的text字段内容: {text_decoded}") - """ - - #text_encoded = tempResult['payload']['updateFeatureRes']['text'] - #text_decoded = base64.b64decode(text_encoded).decode('utf-8') - #print(f"Base64解码后的text字段内容: {text_decoded}") - - - # 获取响应的 JSON 数据 - result = response.json() - with open(RESULT_FILEPATH, "w") as f: - json.dump(result, f, indent=4, ensure_ascii=False) - print(f"结果已成功写入 {RESULT_FILEPATH}") - -submit_config_filepath = os.getenv("SUBMIT_CONFIG_FILEPATH", "./tests/resources/submit_config") -result_filepath = os.getenv("RESULT_FILEPATH", "./out/result") -bad_cases_filepath = os.getenv("BAD_CASES_FILEPATH", "./out/badcase") -#detail_cases_filepath = os.getenv("DETAILED_CASES_FILEPATH", "./out/detailcase.jsonl") - -from typing import Any, Dict, List - -def result2file( - result: Dict[str, Any], - detail_cases: List[Dict[str, Any]] = None -): - assert result_filepath is not None - assert bad_cases_filepath is not None - #assert detailed_cases_filepath is not None - - if result is not None: - with open(result_filepath, "w") as f: - json.dump(result, f, indent=4, ensure_ascii=False) - #if LOCAL_TEST: - # logger.info(f'result:\n {json.dumps(result, indent=4)}') - """ - if detail_cases is not None: - with open(detailed_cases_filepath, "w") as f: - json.dump(detail_cases, f, indent=4, ensure_ascii=False) - if LOCAL_TEST: - logger.info(f'result:\n {json.dumps(detail_cases, indent=4)}') - """ - - -def test_image_prediction(sut_url, image_path): - """发送单张图片到服务端预测""" - url = f"{sut_url}/v1/private/s782b4996" - - try: - with open(image_path, 'rb') as f: - files = {'image': f} - response = requests.post(url, files=files, timeout=30) - - result = response.json() - if result.get('status') != 'success': - return None, f"服务端错误: {result.get('message')}" - - return result.get('top_prediction'), None - except Exception as e: - return None, f"请求错误: {str(e)}" - - - -import random -import time -#from tqdm import tqdm -import os -import requests - -if __name__ == '__main__': - - print(f"\n===== main开始请求接口 ===============================================") - # 1. 首先测试服务健康检查 - - print(f"\n===== 服务健康检查 ===================================================") - response = requests.get(f"{sut_url}/health") - print(response.status_code, response.text) - - """ - # 本地图片路径和真实标签(根据实际情况修改) - image_path = "/path/to/your/test_image.jpg" - true_label = "cat" # 图片的真实标签 - """ - - - """ - # 请求头 - headers = {"Content-Type": "application/json"} - # 请求体(可指定限制处理的图片数量) - body = {"limit": 20 } # 可选参数,限制处理的图片总数 - - # 发送POST请求 - response = requests.post( - f"{sut_url}/v1/private/s782b4996", - data=json.dumps(body), - headers=headers - ) - """ - - """ - # 读取图片文件 - with open(image_path, 'rb') as f: - files = {'image': f} - # 发送POST请求 - response = requests.post(f"{sut_url}/v1/private/s782b4996", files=files) - - - # 解析响应结果 - if response.status_code == 200: - result = response.json() - print("预测评估结果:") - print(f"准确率: {result['metrics']['accuracy']}%") - print(f"平均召回率: {result['metrics']['average_recall']}%") - print(f"处理图片总数: {result['metrics']['total_images']}") - else: - print(f"请求失败,状态码: {response.status_code}") - print(f"错误信息: {response.text}") - """ - - - ############################################################################################### - dataset_root = "/tmp/workspace/256ObjectCategoriesNew" # 数据集根目录 - samples_per_class = 3 # 每个类别抽取的样本数 - image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif') # 支持的图片格式 - - # 结果统计变量 - #total_samples = 0 - correct_predictions = 0 - - # 结果统计变量 - total_samples = 0 - true_positives = 0 - false_positives = 0 - false_negatives = 0 - total_processing_time = 0.0 # 总处理时间(秒) - - """ - # 遍历所有类别文件夹 - for folder_name in tqdm(os.listdir(dataset_root), desc="处理类别"): - folder_path = os.path.join(dataset_root, folder_name) - - - # 提取类别名(从"序号.name"格式中提取name部分) - class_name = folder_name.split('.', 1)[1].strip().lower() - - # 获取文件夹中所有图片 - image_files = [] - for file in os.listdir(folder_path): - if file.lower().endswith(image_extensions): - image_files.append(os.path.join(folder_path, file)) - - # 随机抽取指定数量的图片(如果不足则取全部) - selected_images = random.sample( - image_files, - min(samples_per_class, len(image_files)) - ) - - # 处理选中的图片 - for img_path in selected_images: - total_count += 1 - - # 发送预测请求 - prediction, error = test_image_prediction(sut_url, img_path) - if error: - print(f"处理图片 {img_path} 失败: {error}") - continue - - # 解析预测结果 - pred_class = prediction.get('class_name', '').lower() - confidence = prediction.get('confidence', 0) - - # 判断是否预测正确(真实类别是否在预测类别中) - if class_name in pred_class: - correct_predictions += 1 - - - # 可选:打印详细结果 - print(f"图片: {os.path.basename(img_path)} | 真实: {class_name} | 预测: {pred_class} | 置信度: {confidence:.4f} | {'正确' if is_correct else '错误'}") - """ - - # 遍历所有类别文件夹 - for folder_name in os.listdir(dataset_root): - folder_path = os.path.join(dataset_root, folder_name) - - # 跳过非文件夹的项目 - if not os.path.isdir(folder_path): - continue - - # 提取类别名(从"序号.name"格式中提取name部分) - try: - class_name = folder_name.split('.', 1)[1].strip().lower() - except IndexError: - print(f"警告:文件夹 {folder_name} 命名格式不正确,跳过该文件夹") - continue - - # 获取文件夹中所有图片 - image_files = [] - for file in os.listdir(folder_path): - file_path = os.path.join(folder_path, file) - if os.path.isfile(file_path) and file.lower().endswith(image_extensions): - image_files.append(file_path) - - # 随机抽取指定数量的图片(如果不足则取全部) - selected_images = random.sample( - image_files, - min(samples_per_class, len(image_files)) - ) - - # 处理该文件夹中的所有图片 - for img_path in selected_images: - total_samples += 1 - start_time = time.time() # 记录开始时间 - # 发送预测请求 - prediction, error = test_image_prediction(sut_url, img_path) - - # 计算单张图片处理时间(包括网络请求和模型预测) - processing_time = time.time() - start_time - total_processing_time += processing_time - - if error: - print(f"处理图片 {img_path} 失败: {error}") - # 处理失败的样本视为预测错误 - false_negatives += 1 - continue - - # 解析预测结果 - pred_class = prediction.get('class_name', '').lower() - confidence = prediction.get('confidence', 0) - - # 判断是否预测正确(真实类别是否在预测类别中,不分大小写) - is_correct = class_name in pred_class - - # 更新统计指标 - if is_correct: - true_positives += 1 - else: - false_positives += 1 - false_negatives += 1 - - # 打印详细结果(可选) - print(f"图片: {os.path.basename(img_path)} | 真实: {class_name} | 预测: {pred_class} | 置信度: {confidence:.4f} | {'正确' if is_correct else '错误'}") - - """ - # 计算整体指标(在单标签场景下,准确率=召回率) - if total_samples == 0: - overall_accuracy = 0.0 - overall_recall = 0.0 - else: - overall_accuracy = correct_predictions / total_samples - overall_recall = correct_predictions / total_samples # 整体召回率 - - # 输出统计结果 - print("\n" + "="*50) - print(f"测试总结:") - print(f"总测试样本数: {total_samples}") - print(f"正确预测样本数: {correct_predictions}") - print(f"整体准确率: {overall_accuracy:.4f} ({correct_predictions}/{total_samples})") - print(f"整体召回率: {overall_recall:.4f} ({correct_predictions}/{total_samples})") - print("="*50) - """ - # 初始化结果字典 - result = { - "total_processing_time": round(total_processing_time, 6), - "throughput": 0.0, - "accuracy": 0.0, - "recall": 0.0 - } - - # 计算评估指标 - if total_samples == 0: - print("没有找到任何图片样本") - - - # 准确率 = 正确预测的样本数 / 总预测样本数 - accuracy = true_positives / total_samples * 100 if total_samples > 0 else 0 - - # 召回率 = 正确预测的样本数 / (正确预测的样本数 + 未正确预测的正样本数) - recall_denominator = true_positives + false_negatives - recall = true_positives / recall_denominator * 100 if recall_denominator > 0 else 0 - - # 处理速度计算(每秒钟处理的图片张数) - # 避免除以0(当总时间极短时) - throughput = total_samples / total_processing_time if total_processing_time > 1e-6 else 0 - - # 更新结果字典 - result.update({ - "throughput": round(throughput, 6), - "accuracy": round(accuracy, 6), - "recall": round(recall, 6) - }) - - # 打印最终统计结果 - print("\n" + "="*50) - print(f"总样本数: {total_samples}") - print(f"总处理时间: {total_processing_time:.4f}秒") - print(f"处理速度: {throughput:.2f}张/秒") # 新增:每秒钟处理的图片张数 - print(f"正确预测: {true_positives}") - print(f"错误预测: {total_samples - true_positives}") - print(f"准确率: {accuracy:.4f} ({true_positives}/{total_samples})") - print(f"召回率: {recall:.4f} ({true_positives}/{recall_denominator})") - print("="*50) - - - #result = {} - #result['accuracy_1_1'] = 3 - result2file(result) - - """ - if result['accuracy_1_1'] < 0.9: - log.error(f"1:1正确率未达到90%, 视为产品不可用") - change_product_unavailable() - - - if result['accuracy_1_N'] < 1: - log.error(f"1:N正确率未达到100%, 视为产品不可用") - change_product_unavailable() - if result['1_1_latency'] > 0.5: - log.error(f"1:1平均latency超过0.5s, 视为产品不可用") - change_product_unavailable() - if result['1_N_latency'] > 0.5: - log.error(f"1:N平均latency超过0.5s, 视为产品不可用") - change_product_unavailable() - if result['enroll_latency'] > 1: - log.error(f"enroll(入库)平均latency超过1s, 视为产品不可用") - change_product_unavailable() - """ - exit_code = 0 - - diff --git a/run_callback_new.py b/run_callback_new.py deleted file mode 100644 index 77225b7..0000000 --- a/run_callback_new.py +++ /dev/null @@ -1,1296 +0,0 @@ -import json -import os -import sys -import time -import tempfile -import zipfile -import threading -from collections import defaultdict -from typing import Dict, List - -import yaml -from pydantic import ValidationError - -from schemas.dataset import QueryData -from utils.client_callback import ClientCallback, EvaluateResult, StopException -from utils.logger import log -from utils.service import register_sut -from utils.update_submit import change_product_available -from utils.file import dump_json, load_yaml, unzip_dir, load_json, write_file, dump_yaml -from utils.leaderboard import change_product_unavailable - - -lck = threading.Lock() - -# Environment variables by leaderboard -DATASET_FILEPATH = os.environ["DATASET_FILEPATH"] -RESULT_FILEPATH = os.environ["RESULT_FILEPATH"] - -DETAILED_CASES_FILEPATH = os.environ["DETAILED_CASES_FILEPATH"] -SUBMIT_CONFIG_FILEPATH = os.environ["SUBMIT_CONFIG_FILEPATH"] -BENCHMARK_NAME = os.environ["BENCHMARK_NAME"] -TEST_CONCURRENCY = int(os.getenv('TEST_CONCURRENCY', 1)) -THRESHOLD_OMCER = float(os.getenv('THRESHOLD_OMCER', 0.8)) - -log.info(f"DATASET_FILEPATH: {DATASET_FILEPATH}") -workspace_path = "/tmp/workspace" - - -# Environment variables by kubernetes -MY_POD_IP = os.environ["MY_POD_IP"] - -# constants -RESOURCE_NAME = BENCHMARK_NAME - -# Environment variables by judge_flow_config -LANG = os.getenv("lang") -SUT_CPU = os.getenv("SUT_CPU", "2") -SUT_MEMORY = os.getenv("SUT_MEMORY", "4Gi") -SUT_VGPU = os.getenv("SUT_VGPU", "1") -#SUT_VGPU_MEM = os.getenv("SUT_VGPU_MEM", str(1843 * int(SUT_VGPU))) -#SUT_VGPU_CORES = os.getenv("SUT_VGPU_CORES", str(8 * int(SUT_VGPU))) -SUT_VGPU_ACCELERATOR = os.getenv("SUT_VGPU_ACCELERATOR", "iluvatar-BI-V100") -RESOURCE_TYPE = os.getenv("RESOURCE_TYPE", "vgpu") -assert RESOURCE_TYPE in [ - "cpu", - "vgpu", -], "benchmark judge_flow_config error: RESOURCE_TYPE should be cpu or vgpu" - - -unzip_dir(DATASET_FILEPATH, workspace_path) - -def get_sut_url_kubernetes(): - with open(SUBMIT_CONFIG_FILEPATH, "r") as f: - submit_config = yaml.safe_load(f) - assert isinstance(submit_config, dict) - - submit_config.setdefault("values", {}) - - submit_config["values"]["containers"] = [ - { - "name": "corex-container", - "image": "harbor.4pd.io/lab-platform/inf/python:3.9", #镜像 - "command": ["sleep"], # 替换为你的模型启动命令,使用python解释器 - "args": ["3600"], # 替换为你的模型参数,运行我的推理脚本 - - # 添加存储卷挂载 - #"volumeMounts": [ - # { - # "name": "model-volume", - # "mountPath": "/model" # 挂载到/model目录 - # } - #] - } - ] - - """ - # 添加存储卷配置 - submit_config["values"]["volumes"] = [ - { - "name": "model-volume", - "persistentVolumeClaim": { - "claimName": "sid-model-pvc" # 使用已有的PVC - } - } - ] - """ - - """ - # Inject specified cpu and memory - resource = { - "cpu": SUT_CPU, - "memory": SUT_MEMORY, - } - """ - submit_config["values"]["resources"] = { - "requests":{}, - "limits": {}, - } - - limits = submit_config["values"]["resources"]["limits"] - requests = submit_config["values"]["resources"]["requests"] - - - """ - # ########## 关键修改:替换为iluvatar GPU配置 ########## - if RESOURCE_TYPE == "vgpu": # 假设你的模型需要GPU - # 替换nvidia资源键为iluvatar.ai/gpu - vgpu_resource = { - "iluvatar.ai/gpu": SUT_VGPU, # 对应你的GPU资源键 - # 若需要其他资源(如显存),按你的K8s配置补充,例如: - # "iluvatar.ai/gpumem": SUT_VGPU_MEM, - } - limits.update(vgpu_resource) - requests.update(vgpu_resource) - # 节点选择器:替换为你的accelerator标签 - submit_config["values"]["nodeSelector"] = { - "contest.4pd.io/accelerator": "iluvatar-BI-V100" # 你的节点标签 - } - # 容忍度:替换为你的tolerations配置 - submit_config["values"]["tolerations"] = [ - { - "key": "hosttype", - "operator": "Equal", - "value": "iluvatar", - "effect": "NoSchedule", - } - ] - # ######################################### - # 禁止CPU模式下使用GPU资源(保持原逻辑) - else: - if "iluvatar.ai/gpu" in limits or "iluvatar.ai/gpu" in requests: - log.error("禁止在CPU模式下使用GPU资源") - sys.exit(1) - - - - #gpukeys = ["iluvatar.ai/gpu"] # 检查iluvatar GPU键 - #for key in gpukeys: - # if key in limits or key in requests: - # log.error("禁止使用vgpu资源") - # sys.exit(1) - - """ - - # 替换nvidia资源键为iluvatar.ai/gpu - vgpu_resource = { - "iluvatar.ai/gpu": SUT_VGPU, # 对应你的GPU资源键 - # 若需要其他资源(如显存),按你的K8s配置补充,例如: - # "iluvatar.ai/gpumem": SUT_VGPU_MEM, - } - limits.update(vgpu_resource) - requests.update(vgpu_resource) - # 节点选择器:替换为你的accelerator标签 - submit_config["values"]["nodeSelector"] = { - "contest.4pd.io/accelerator": "iluvatar-BI-V100" # 你的节点标签 - } - # 容忍度:替换为你的tolerations配置 - """ - submit_config["values"]["tolerations"] = [ - { - "key": "hosttype", - "operator": "Equal", - "value": "iluvatar", - "effect": "NoSchedule", - }, - { - "key": "hosttype", - "operator": "Equal", - "value": "arm64", - "effect": "NoSchedule", - }, - { - "key": "hosttype", - "operator": "Equal", - "value": "myinit", - "effect": "NoSchedule", - }, - { - "key": "hosttype", - "operator": "Equal", - "value": "middleware", - "effect": "NoSchedule", - } - - ] - """ - """ - { - "key": "node-role.kubernetes.io/master", - "operator": "Exists", - "effect": "NoSchedule", - }, - { - "key": "node.kubernetes.io/not-ready", - "operator": "Exists", - "effect": "NoExecute", - "tolerationSeconds": 300 - }, - { - "key": "node.kubernetes.io/unreachable", - "operator": "Exists", - "effect": "NoExecute", - "tolerationSeconds": 300 - } - """ - - - log.info(f"submit_config: {submit_config}") - log.info(f"RESOURCE_NAME: {RESOURCE_NAME}") - - return register_sut(submit_config, RESOURCE_NAME).replace( - "ws://", "http://" - ) - - -def get_sut_url(): - return get_sut_url_kubernetes() - -#SUT_URL = get_sut_url() -#os.environ["SUT_URL"] = SUT_URL - -""" -def load_dataset( - dataset_filepath: str, -) -> Dict[str, List[QueryData]]: - dataset_path = tempfile.mkdtemp() - - with zipfile.ZipFile(dataset_filepath) as zf: - zf.extractall(dataset_path) - - basename = os.path.basename(dataset_filepath) - datayaml = os.path.join(dataset_path, "data.yaml") - if not os.path.exists(datayaml): - sub_dataset_paths = os.listdir(dataset_path) - dataset = {} - for sub_dataset_path in sub_dataset_paths: - sub_dataset = load_dataset( - os.path.join(dataset_path, sub_dataset_path) - ) - for k, v in sub_dataset.items(): - k = os.path.join(basename, k) - dataset[k] = v - return dataset - - with open(datayaml, "r") as f: - data = yaml.safe_load(f) - assert isinstance(data, dict) - - lang = LANG - data_lang = data.get("global", {}).get("lang") - if lang is None and data_lang is not None: - if data_lang is not None: - # 使用配置中的语言类型 - lang = data_lang - if lang is None and basename.startswith("asr.") and len(basename) == 4 + 2: - # 数据集名称为asr.en 可以认为语言为en - lang = basename[4:] - if lang is None: - log.error( - "数据集错误 通过data.yaml中的 global.lang 或 数据集名称 asr.xx 指定语言类型" - ) - sys.exit(1) - - query_data = data.get("query_data", []) - audio_size_map = {} - for query in query_data: - query["lang"] = lang - query["file"] = os.path.join(dataset_path, query["file"]) - audio_size_map[query["file"]] = os.path.getsize(query["file"]) - # 根据音频大小排序 - query_data = sorted( - query_data, key=lambda x: audio_size_map[x["file"]], reverse=True - ) - valid_query_data = [] - for i_query_data in query_data: - try: - valid_query_data.append(QueryData.model_validate(i_query_data)) - except ValidationError: - log.error("数据集错误 数据中query_data格式错误") - sys.exit(1) - - return { - basename: valid_query_data, - } - - -def merge_query_data(dataset: Dict[str, List[QueryData]]) -> List[QueryData]: - query_datas = [] - for query_data in dataset.values(): - query_datas.extend(query_data) - return query_datas - - -def run_one_predict( - client: ClientCallback, query_data: QueryData, task_id: str -) -> EvaluateResult: - try: - client.predict(None, query_data.file, query_data.duration, task_id) - except StopException: - sys.exit(1) - - client.finished.wait() - - if client.error is not None: - sys.exit(1) - - client.app_on = False - - try: - with lck: - ret = client.evaluate(query_data) - return ret - except StopException: - sys.exit(1) -""" - -""" -def predict_task( - client: ClientCallback, task_id: int, query_data: QueryData, test_results: list -): - log.info(f"Task-{task_id}开始评测") - test_results[task_id] = run_one_predict(client, query_data, str(task_id)) - - -def merge_concurrent_result(evaluate_results: List[EvaluateResult]) -> Dict: - cer = 0.0 - align_start = {} - align_end = {} - first_word_distance_sum = 0.0 - last_word_distance_sum = 0.0 - rtf = 0.0 - first_receive_delay: float = 0.0 - query_count: int = 0 - voice_count: int = 0 - pred_punctuation_num: int = 0 - label_punctuation_num: int = 0 - pred_sentence_punctuation_num: int = 0 - label_setence_punctuation_num: int = 0 - - for evalute_result in evaluate_results: - cer += evalute_result.cer - for k, v in evalute_result.align_start.items(): - align_start.setdefault(k, 0) - align_start[k] += v - for k, v in evalute_result.align_end.items(): - align_end.setdefault(k, 0) - align_end[k] += v - first_word_distance_sum += evalute_result.first_word_distance_sum - last_word_distance_sum += evalute_result.last_word_distance_sum - rtf += evalute_result.rtf - first_receive_delay += evalute_result.first_receive_delay - query_count += evalute_result.query_count - voice_count += evalute_result.voice_count - pred_punctuation_num += evalute_result.pred_punctuation_num - label_punctuation_num += evalute_result.label_punctuation_num - pred_sentence_punctuation_num += ( - evalute_result.pred_sentence_punctuation_num - ) - label_setence_punctuation_num += ( - evalute_result.label_setence_punctuation_num - ) - lens = len(evaluate_results) - cer /= lens - for k, v in align_start.items(): - align_start[k] /= voice_count - for k, v in align_end.items(): - align_end[k] /= voice_count - first_word_distance = first_word_distance_sum / voice_count - last_word_distance = last_word_distance_sum / voice_count - rtf /= lens - first_receive_delay /= lens - json_result = { - "one_minus_cer": 1 - cer, - "first_word_distance_mean": first_word_distance, - "last_word_distance_mean": last_word_distance, - "query_count": query_count // lens, - "voice_count": voice_count // lens, - "rtf": rtf, - "first_receive_delay": first_receive_delay, - "punctuation_ratio": ( - pred_punctuation_num / label_punctuation_num - if label_punctuation_num > 0 - else 1.0 - ), - "sentence_punctuation_ratio": ( - pred_sentence_punctuation_num / label_setence_punctuation_num - if label_setence_punctuation_num > 0 - else 1.0 - ), - } - for k, v in align_start.items(): - json_result["start_word_%dms_ratio" % k] = v - for k, v in align_end.items(): - json_result["end_word_%dms_ratio" % k] = v - - return json_result - - -def merge_result(result: Dict[str, List[EvaluateResult]]) -> Dict: - json_result = {} - for lang, evaluate_results in result.items(): - if len(evaluate_results) == 0: - continue - cer = 0.0 - align_start = {} - align_end = {} - first_word_distance_sum = 0.0 - last_word_distance_sum = 0.0 - rtf = 0.0 - first_receive_delay: float = 0.0 - query_count: int = 0 - voice_count: int = 0 - pred_punctuation_num: int = 0 - label_punctuation_num: int = 0 - pred_sentence_punctuation_num: int = 0 - label_setence_punctuation_num: int = 0 - for evalute_result in evaluate_results: - cer += evalute_result.cer - for k, v in evalute_result.align_start.items(): - align_start.setdefault(k, 0) - align_start[k] += v - for k, v in evalute_result.align_end.items(): - align_end.setdefault(k, 0) - align_end[k] += v - first_word_distance_sum += evalute_result.first_word_distance_sum - last_word_distance_sum += evalute_result.last_word_distance_sum - rtf += evalute_result.rtf - first_receive_delay += evalute_result.first_receive_delay - query_count += evalute_result.query_count - voice_count += evalute_result.voice_count - pred_punctuation_num += evalute_result.pred_punctuation_num - label_punctuation_num += evalute_result.label_punctuation_num - pred_sentence_punctuation_num += ( - evalute_result.pred_sentence_punctuation_num - ) - label_setence_punctuation_num += ( - evalute_result.label_setence_punctuation_num - ) - lens = len(evaluate_results) - cer /= lens - for k, v in align_start.items(): - align_start[k] /= voice_count - for k, v in align_end.items(): - align_end[k] /= voice_count - first_word_distance = first_word_distance_sum / voice_count - last_word_distance = last_word_distance_sum / voice_count - rtf /= lens - first_receive_delay /= lens - lang_result = { - "one_minus_cer": 1 - cer, - "first_word_distance_mean": first_word_distance, - "last_word_distance_mean": last_word_distance, - "query_count": 1, - "voice_count": voice_count, - "rtf": rtf, - "first_receive_delay": first_receive_delay, - "punctuation_ratio": ( - pred_punctuation_num / label_punctuation_num - if label_punctuation_num > 0 - else 1.0 - ), - "sentence_punctuation_ratio": ( - pred_sentence_punctuation_num / label_setence_punctuation_num - if label_setence_punctuation_num > 0 - else 1.0 - ), - } - for k, v in align_start.items(): - lang_result["start_word_%dms_ratio" % k] = v - for k, v in align_end.items(): - lang_result["end_word_%dms_ratio" % k] = v - if lang == "": - json_result.update(lang_result) - else: - json_result[lang] = lang_result - return json_result -""" - -""" -def main(): - log.info(f'{TEST_CONCURRENCY=}, {THRESHOLD_OMCER=}') - dataset = load_dataset(DATASET_FILEPATH) - query_datas = merge_query_data(dataset) - - #获取 ASR 服务 URL(通常从 Kubernetes 配置) - sut_url = get_sut_url() - - #创建多个客户端实例(每个客户端监听不同端口,如 80、81、82...) - port_base = 80 - clients = [ClientCallback(sut_url, port_base + i) for i in range(TEST_CONCURRENCY)] - - #准备测试数据与线程 - detail_cases = [] - # we use the same test data for all requests - query_data = query_datas[0] - - test_results = [None] * len(clients) - test_threads = [threading.Thread(target=predict_task, args=(client, task_id, query_data, test_results)) - for task_id, client in enumerate(clients)] - - #启动并发测试,启动线程并间隔10秒,设置超时时间为1小时 - for t in test_threads: - t.start() - time.sleep(10) - [t.join(timeout=3600) for t in test_threads] - - #合并结果与评估 - final_result = merge_concurrent_result(test_results) - product_avaiable = all([c.product_avaiable for c in clients]) - - final_result['concurrent_req'] = TEST_CONCURRENCY - if final_result['one_minus_cer'] < THRESHOLD_OMCER: - product_avaiable = False - - if not product_avaiable: - final_result['success'] = False - change_product_available() - else: - final_result['success'] = True - - #保存结果, - log.info( - "指标结果为: %s", json.dumps(final_result, indent=2, ensure_ascii=False) - ) - - time.sleep(120) - #打印并保存最终结果到文件 - with open(RESULT_FILEPATH, "w") as f: - json.dump(final_result, f, indent=2, ensure_ascii=False) - #保存详细测试用例结果 - with open(DETAILED_CASES_FILEPATH, "w") as f: - json.dump(detail_cases, f, indent=2, ensure_ascii=False) -""" - -############################################################################# - -import requests -import base64 - -def gen_req_body(apiname, APPId, file_path=None, featureId=None, featureInfo=None, dstFeatureId=None): - """ - 生成请求的body - :param apiname - :param APPId: Appid - :param file_name: 文件路径 - :return: - """ - if apiname == 'createFeature': - - with open(file_path, "rb") as f: - audioBytes = f.read() - body = { - "header": { - "app_id": APPId, - "status": 3 - }, - "parameter": { - "s782b4996": { - "func": "createFeature", - "groupId": "test_voiceprint_e", - "featureId": featureId, - "featureInfo": featureInfo, - "createFeatureRes": { - "encoding": "utf8", - "compress": "raw", - "format": "json" - } - } - }, - "payload": { - "resource": { - "encoding": "lame", - "sample_rate": 16000, - "channels": 1, - "bit_depth": 16, - "status": 3, - "audio": str(base64.b64encode(audioBytes), 'UTF-8') - } - } - } - elif apiname == 'createGroup': - - body = { - "header": { - "app_id": APPId, - "status": 3 - }, - "parameter": { - "s782b4996": { - "func": "createGroup", - "groupId": "test_voiceprint_e", - "groupName": "vip_user", - "groupInfo": "store_vip_user_voiceprint", - "createGroupRes": { - "encoding": "utf8", - "compress": "raw", - "format": "json" - } - } - } - } - elif apiname == 'deleteFeature': - - body = { - "header": { - "app_id": APPId, - "status": 3 - - }, - "parameter": { - "s782b4996": { - "func": "deleteFeature", - "groupId": "iFLYTEK_examples_groupId", - "featureId": "iFLYTEK_examples_featureId", - "deleteFeatureRes": { - "encoding": "utf8", - "compress": "raw", - "format": "json" - } - } - } - } - elif apiname == 'queryFeatureList': - - body = { - "header": { - "app_id": APPId, - "status": 3 - }, - "parameter": { - "s782b4996": { - "func": "queryFeatureList", - "groupId": "user_voiceprint_2", - "queryFeatureListRes": { - "encoding": "utf8", - "compress": "raw", - "format": "json" - } - } - } - } - elif apiname == 'searchFea': - - with open(file_path, "rb") as f: - audioBytes = f.read() - body = { - "header": { - "app_id": APPId, - "status": 3 - }, - "parameter": { - "s782b4996": { - "func": "searchFea", - "groupId": "test_voiceprint_e", - "topK": 1, - "searchFeaRes": { - "encoding": "utf8", - "compress": "raw", - "format": "json" - } - } - }, - "payload": { - "resource": { - "encoding": "lame", - "sample_rate": 16000, - "channels": 1, - "bit_depth": 16, - "status": 3, - "audio": str(base64.b64encode(audioBytes), 'UTF-8') - } - } - } - elif apiname == 'searchScoreFea': - - with open(file_path, "rb") as f: - audioBytes = f.read() - body = { - "header": { - "app_id": APPId, - "status": 3 - }, - "parameter": { - "s782b4996": { - "func": "searchScoreFea", - "groupId": "test_voiceprint_e", - "dstFeatureId": dstFeatureId, - "searchScoreFeaRes": { - "encoding": "utf8", - "compress": "raw", - "format": "json" - } - } - }, - "payload": { - "resource": { - "encoding": "lame", - "sample_rate": 16000, - "channels": 1, - "bit_depth": 16, - "status": 3, - "audio": str(base64.b64encode(audioBytes), 'UTF-8') - } - } - } - elif apiname == 'updateFeature': - - with open(file_path, "rb") as f: - audioBytes = f.read() - body = { - "header": { - "app_id": APPId, - "status": 3 - }, - "parameter": { - "s782b4996": { - "func": "updateFeature", - "groupId": "iFLYTEK_examples_groupId", - "featureId": "iFLYTEK_examples_featureId", - "featureInfo": "iFLYTEK_examples_featureInfo_update", - "updateFeatureRes": { - "encoding": "utf8", - "compress": "raw", - "format": "json" - } - } - }, - "payload": { - "resource": { - "encoding": "lame", - "sample_rate": 16000, - "channels": 1, - "bit_depth": 16, - "status": 3, - "audio": str(base64.b64encode(audioBytes), 'UTF-8') - } - } - } - elif apiname == 'deleteGroup': - body = { - "header": { - "app_id": APPId, - "status": 3 - }, - "parameter": { - "s782b4996": { - "func": "deleteGroup", - "groupId": "iFLYTEK_examples_groupId", - "deleteGroupRes": { - "encoding": "utf8", - "compress": "raw", - "format": "json" - } - } - } - } - else: - raise Exception( - "输入的apiname不在[createFeature, createGroup, deleteFeature, queryFeatureList, searchFea, searchScoreFea,updateFeature]内,请检查") - return body - - - -log.info(f"开始请求获取到SUT服务URL") -# 获取SUT服务URL -sut_url = get_sut_url() -print(f"获取到的SUT_URL: {sut_url}") # 调试输出 -log.info(f"获取到SUT服务URL: {sut_url}") - -from urllib.parse import urlparse - -# 全局变量 -text_decoded = None - -###################################新增新增################################ -def req_url(api_name, APPId, file_path=None, featureId=None, featureInfo=None, dstFeatureId=None): - """ - 开始请求 - :param APPId: APPID - :param file_path: body里的文件路径 - :return: - """ - - global text_decoded - - body = gen_req_body(apiname=api_name, APPId=APPId, file_path=file_path, featureId=featureId, featureInfo=featureInfo, dstFeatureId=dstFeatureId) - #request_url = 'https://ai-cloud.4paradigm.com:9443/sid/v1/private/s782b4996' - - #request_url = 'https://sut:80/sid/v1/private/s782b4996' - - #headers = {'content-type': "application/json", 'host': 'ai-cloud.4paradigm.com', 'appid': APPId} - - parsed_url = urlparse(sut_url) - headers = {'content-type': "application/json", 'host': parsed_url.hostname, 'appid': APPId} - - # 1. 首先测试服务健康检查 - response = requests.get(f"{sut_url}/health") - print(response.status_code, response.text) - - - # 请求头 - headers = {"Content-Type": "application/json"} - # 请求体(可指定限制处理的图片数量) - body = {"limit": 20 } # 可选参数,限制处理的图片总数 - - # 发送POST请求 - response = requests.post( - f"{sut_url}/v1/private/s782b4996", - data=json.dumps(body), - headers=headers - ) - - # 解析响应结果 - if response.status_code == 200: - result = response.json() - print("预测评估结果:") - print(f"准确率: {result['metrics']['accuracy']}%") - print(f"平均召回率: {result['metrics']['average_recall']}%") - print(f"处理图片总数: {result['metrics']['total_images']}") - else: - print(f"请求失败,状态码: {response.status_code}") - print(f"错误信息: {response.text}") - - - - - # 添加基本认证信息 - auth = ('llm', 'Rmf4#LcG(iFZrjU;2J') - #response = requests.post(request_url, data=json.dumps(body), headers=headers, auth=auth) - - #response = requests.post(sut_url + "/predict", data=json.dumps(body), headers=headers, auth=auth) - #response = requests.post(f"{sut_url}/sid/v1/private/s782b4996", data=json.dumps(body), headers=headers, auth=auth) - """ - response = requests.post(f"{sut_url}/v1/private/s782b4996", data=json.dumps(body), headers=headers) - """ - - - - - #print("HTTP状态码:", response.status_code) - #print("原始响应内容:", response.text) # 先打印原始内容 - #print(f"请求URL: {sut_url + '/v1/private/s782b4996'}") - #print(f"请求headers: {headers}") - #print(f"请求body: {body}") - - - - #tempResult = json.loads(response.content.decode('utf-8')) - #print(tempResult) - - """ - # 对text字段进行Base64解码 - if 'payload' in tempResult and 'updateFeatureRes' in tempResult['payload']: - text_encoded = tempResult['payload']['updateFeatureRes']['text'] - text_decoded = base64.b64decode(text_encoded).decode('utf-8') - print(f"Base64解码后的text字段内容: {text_decoded}") - """ - - #text_encoded = tempResult['payload']['updateFeatureRes']['text'] - #text_decoded = base64.b64decode(text_encoded).decode('utf-8') - #print(f"Base64解码后的text字段内容: {text_decoded}") - - - # 获取响应的 JSON 数据 - result = response.json() - with open(RESULT_FILEPATH, "w") as f: - json.dump(result, f, indent=4, ensure_ascii=False) - print(f"结果已成功写入 {RESULT_FILEPATH}") - -submit_config_filepath = os.getenv("SUBMIT_CONFIG_FILEPATH", "./tests/resources/submit_config") -result_filepath = os.getenv("RESULT_FILEPATH", "./out/result") -bad_cases_filepath = os.getenv("BAD_CASES_FILEPATH", "./out/badcase") -#detail_cases_filepath = os.getenv("DETAILED_CASES_FILEPATH", "./out/detailcase.jsonl") - -from typing import Any, Dict, List - -def result2file( - result: Dict[str, Any], - detail_cases: List[Dict[str, Any]] = None -): - assert result_filepath is not None - assert bad_cases_filepath is not None - #assert detailed_cases_filepath is not None - - if result is not None: - with open(result_filepath, "w") as f: - json.dump(result, f, indent=4, ensure_ascii=False) - #if LOCAL_TEST: - # logger.info(f'result:\n {json.dumps(result, indent=4)}') - """ - if detail_cases is not None: - with open(detailed_cases_filepath, "w") as f: - json.dump(detail_cases, f, indent=4, ensure_ascii=False) - if LOCAL_TEST: - logger.info(f'result:\n {json.dumps(detail_cases, indent=4)}') - """ - - -def test_image_prediction(sut_url, image_path): - """发送单张图片到服务端预测""" - url = f"{sut_url}/v1/private/s782b4996" - - try: - with open(image_path, 'rb') as f: - files = {'image': f} - response = requests.post(url, files=files, timeout=30) - - result = response.json() - if result.get('status') != 'success': - return None, f"服务端错误: {result.get('message')}" - - return result.get('top_prediction'), None - except Exception as e: - return None, f"请求错误: {str(e)}" - - - -import random -import time -#from tqdm import tqdm -import os -import requests - -if __name__ == '__main__': - - print(f"\n===== main开始请求接口 ===============================================") - # 1. 首先测试服务健康检查 - - print(f"\n===== 服务健康检查 ===================================================") - response = requests.get(f"{sut_url}/health") - print(response.status_code, response.text) - - """ - # 本地图片路径和真实标签(根据实际情况修改) - image_path = "/path/to/your/test_image.jpg" - true_label = "cat" # 图片的真实标签 - """ - - - """ - # 请求头 - headers = {"Content-Type": "application/json"} - # 请求体(可指定限制处理的图片数量) - body = {"limit": 20 } # 可选参数,限制处理的图片总数 - - # 发送POST请求 - response = requests.post( - f"{sut_url}/v1/private/s782b4996", - data=json.dumps(body), - headers=headers - ) - """ - - """ - # 读取图片文件 - with open(image_path, 'rb') as f: - files = {'image': f} - # 发送POST请求 - response = requests.post(f"{sut_url}/v1/private/s782b4996", files=files) - - - # 解析响应结果 - if response.status_code == 200: - result = response.json() - print("预测评估结果:") - print(f"准确率: {result['metrics']['accuracy']}%") - print(f"平均召回率: {result['metrics']['average_recall']}%") - print(f"处理图片总数: {result['metrics']['total_images']}") - else: - print(f"请求失败,状态码: {response.status_code}") - print(f"错误信息: {response.text}") - """ - - - ############################################################################################### - dataset_root = "/tmp/workspace/256ObjectCategoriesNew" # 数据集根目录 - samples_per_class = 3 # 每个类别抽取的样本数 - image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif') # 支持的图片格式 - - # 结果统计变量 - #total_samples = 0 - correct_predictions = 0 - - # 结果统计变量 - total_samples = 0 - - # CUDA统计 - cuda_true_positives = 0 - cuda_false_positives = 0 - cuda_false_negatives = 0 - cuda_total_processing_time = 0.0 - - # CPU统计 - cpu_true_positives = 0 - cpu_false_positives = 0 - cpu_false_negatives = 0 - cpu_total_processing_time = 0.0 - - - - true_positives = 0 - false_positives = 0 - false_negatives = 0 - total_processing_time = 0.0 # 总处理时间(秒) - - """ - # 遍历所有类别文件夹 - for folder_name in tqdm(os.listdir(dataset_root), desc="处理类别"): - folder_path = os.path.join(dataset_root, folder_name) - - - # 提取类别名(从"序号.name"格式中提取name部分) - class_name = folder_name.split('.', 1)[1].strip().lower() - - # 获取文件夹中所有图片 - image_files = [] - for file in os.listdir(folder_path): - if file.lower().endswith(image_extensions): - image_files.append(os.path.join(folder_path, file)) - - # 随机抽取指定数量的图片(如果不足则取全部) - selected_images = random.sample( - image_files, - min(samples_per_class, len(image_files)) - ) - - # 处理选中的图片 - for img_path in selected_images: - total_count += 1 - - # 发送预测请求 - prediction, error = test_image_prediction(sut_url, img_path) - if error: - print(f"处理图片 {img_path} 失败: {error}") - continue - - # 解析预测结果 - pred_class = prediction.get('class_name', '').lower() - confidence = prediction.get('confidence', 0) - - # 判断是否预测正确(真实类别是否在预测类别中) - if class_name in pred_class: - correct_predictions += 1 - - - # 可选:打印详细结果 - print(f"图片: {os.path.basename(img_path)} | 真实: {class_name} | 预测: {pred_class} | 置信度: {confidence:.4f} | {'正确' if is_correct else '错误'}") - """ - - # 遍历所有类别文件夹 - for folder_name in os.listdir(dataset_root): - folder_path = os.path.join(dataset_root, folder_name) - - # 跳过非文件夹的项目 - if not os.path.isdir(folder_path): - continue - - # 提取类别名(从"序号.name"格式中提取name部分) - try: - class_name = folder_name.split('.', 1)[1].strip().lower() - except IndexError: - print(f"警告:文件夹 {folder_name} 命名格式不正确,跳过该文件夹") - continue - - # 获取文件夹中所有图片 - image_files = [] - for file in os.listdir(folder_path): - file_path = os.path.join(folder_path, file) - if os.path.isfile(file_path) and file.lower().endswith(image_extensions): - image_files.append(file_path) - - # 随机抽取指定数量的图片(如果不足则取全部) - selected_images = random.sample( - image_files, - min(samples_per_class, len(image_files)) - ) - - # 处理该文件夹中的所有图片 - for img_path in selected_images: - total_samples += 1 - start_time = time.time() # 记录开始时间 - # 发送预测请求 - #prediction, error = test_image_prediction(sut_url, img_path) - - # 获取cuda和cpu的预测结果及处理时间 - cuda_pred, cpu_pred, error, processing_time = test_image_prediction(sut_url, img_path) - - # 计算单张图片处理时间(包括网络请求和模型预测) - #processing_time = time.time() - start_time - #total_processing_time += processing_time - - # 累加处理时间(单次请求的时间同时用于cuda和cpu统计) - cuda_total_processing_time += processing_time - cpu_total_processing_time += processing_time - - if error: - print(f"处理图片 {img_path} 失败: {error}") - # 处理失败的样本视为预测错误 - #false_negatives += 1 - # 处理失败时两种设备都记为错误 - cuda_false_negatives += 1 - cpu_false_negatives += 1 - continue - - # 解析预测结果 - #pred_class = prediction.get('class_name', '').lower() - #confidence = prediction.get('confidence', 0) - - # 判断是否预测正确(真实类别是否在预测类别中,不分大小写) - #is_correct = class_name in pred_class - - # 更新统计指标 - #if is_correct: - # true_positives += 1 - #else: - # false_positives += 1 - # false_negatives += 1 - - # 处理CUDA预测结果 - if cuda_pred: - cuda_pred_class = cuda_pred.get('class_name', '').lower() - cuda_is_correct = class_name in cuda_pred_class - - if cuda_is_correct: - cuda_true_positives += 1 - else: - cuda_false_positives += 1 - cuda_false_negatives += 1 - - # 处理CPU预测结果 - if cpu_pred: - cpu_pred_class = cpu_pred.get('class_name', '').lower() - cpu_is_correct = class_name in cpu_pred_class - - if cpu_is_correct: - cpu_true_positives += 1 - else: - cpu_false_positives += 1 - cpu_false_negatives += 1 - - # 打印详细结果(可选) - #print(f"图片: {os.path.basename(img_path)} | 真实: {class_name} | 预测: {pred_class} | 置信度: {confidence:.4f} | {'正确' if is_correct else '错误'}") - print(f"图片: {os.path.basename(img_path)} | 真实: {class_name}") - print(f"CUDA预测: {cuda_pred_class} | {'正确' if cuda_is_correct else '错误'}") - print(f"CPU预测: {cpu_pred_class} | {'正确' if cpu_is_correct else '错误'}\n") - - """ - # 计算整体指标(在单标签场景下,准确率=召回率) - if total_samples == 0: - overall_accuracy = 0.0 - overall_recall = 0.0 - else: - overall_accuracy = correct_predictions / total_samples - overall_recall = correct_predictions / total_samples # 整体召回率 - - # 输出统计结果 - print("\n" + "="*50) - print(f"测试总结:") - print(f"总测试样本数: {total_samples}") - print(f"正确预测样本数: {correct_predictions}") - print(f"整体准确率: {overall_accuracy:.4f} ({correct_predictions}/{total_samples})") - print(f"整体召回率: {overall_recall:.4f} ({correct_predictions}/{total_samples})") - print("="*50) - """ - # 初始化结果字典 - """ - result = { - "total_processing_time": round(total_processing_time, 6), - "throughput": 0.0, - "accuracy": 0.0, - "recall": 0.0 - } - - # 计算评估指标 - if total_samples == 0: - print("没有找到任何图片样本") - - - # 准确率 = 正确预测的样本数 / 总预测样本数 - accuracy = true_positives / total_samples * 100 if total_samples > 0 else 0 - - # 召回率 = 正确预测的样本数 / (正确预测的样本数 + 未正确预测的正样本数) - recall_denominator = true_positives + false_negatives - recall = true_positives / recall_denominator * 100 if recall_denominator > 0 else 0 - - # 处理速度计算(每秒钟处理的图片张数) - # 避免除以0(当总时间极短时) - throughput = total_samples / total_processing_time if total_processing_time > 1e-6 else 0 - - # 更新结果字典 - result.update({ - "throughput": round(throughput, 6), - "accuracy": round(accuracy, 6), - "recall": round(recall, 6) - }) - """ - - - # 初始化结果字典 - result = { - # CUDA指标 - "cuda_total_processing_time": round(cuda_total_processing_time, 6), - "cuda_throughput": 0.0, - "cuda_accuracy": 0.0, - "cuda_recall": 0.0, - - # CPU指标 - "cpu_total_processing_time": round(cpu_total_processing_time, 6), - "cpu_throughput": 0.0, - "cpu_accuracy": 0.0, - "cpu_recall": 0.0, - - - } - - # 计算CUDA指标 - cuda_accuracy = cuda_true_positives / total_samples * 100 if total_samples > 0 else 0 - cuda_recall_denominator = cuda_true_positives + cuda_false_negatives - cuda_recall = cuda_true_positives / cuda_recall_denominator * 100 if cuda_recall_denominator > 0 else 0 - cuda_throughput = total_samples / cuda_total_processing_time if cuda_total_processing_time > 1e-6 else 0 - - # 计算CPU指标 - cpu_accuracy = cpu_true_positives / total_samples * 100 if total_samples > 0 else 0 - cpu_recall_denominator = cpu_true_positives + cpu_false_negatives - cpu_recall = cpu_true_positives / cpu_recall_denominator * 100 if cpu_recall_denominator > 0 else 0 - cpu_throughput = total_samples / cpu_total_processing_time if cpu_total_processing_time > 1e-6 else 0 - - # 更新结果字典 - result.update({ - # CUDA指标 - "cuda_throughput": round(cuda_throughput, 6), - "cuda_accuracy": round(cuda_accuracy, 6), - "cuda_recall": round(cuda_recall, 6), - - # CPU指标 - "cpu_throughput": round(cpu_throughput, 6), - "cpu_accuracy": round(cpu_accuracy, 6), - "cpu_recall": round(cpu_recall, 6) - }) - - # 打印最终统计结果 - print("\n" + "="*50) - print(f"总样本数: {total_samples}") - - print("\nCUDA 统计:") - print(f"总处理时间: {cuda_total_processing_time:.4f}秒") - print(f"处理速度: {result['cuda_throughput']:.2f}张/秒") - print(f"正确预测: {cuda_true_positives}") - print(f"准确率: {result['cuda_accuracy']:.4f}%") - print(f"召回率: {result['cuda_recall']:.4f}%") - - print("\nCPU 统计:") - print(f"总处理时间: {cpu_total_processing_time:.4f}秒") - print(f"处理速度: {result['cpu_throughput']:.2f}张/秒") - print(f"正确预测: {cpu_true_positives}") - print(f"准确率: {result['cpu_accuracy']:.4f}%") - print(f"召回率: {result['cpu_recall']:.4f}%") - print("="*50) - - - #result = {} - #result['accuracy_1_1'] = 3 - result2file(result) - - """ - if result['accuracy_1_1'] < 0.9: - log.error(f"1:1正确率未达到90%, 视为产品不可用") - change_product_unavailable() - - - if result['accuracy_1_N'] < 1: - log.error(f"1:N正确率未达到100%, 视为产品不可用") - change_product_unavailable() - if result['1_1_latency'] > 0.5: - log.error(f"1:1平均latency超过0.5s, 视为产品不可用") - change_product_unavailable() - if result['1_N_latency'] > 0.5: - log.error(f"1:N平均latency超过0.5s, 视为产品不可用") - change_product_unavailable() - if result['enroll_latency'] > 1: - log.error(f"enroll(入库)平均latency超过1s, 视为产品不可用") - change_product_unavailable() - """ - exit_code = 0 - -