"""Evaluate a local transformers image-classification model on Caltech-256 samples (CPU only)."""

import json
import os
import random
import time
from io import BytesIO
from typing import Optional

import requests
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Force CPU execution.
device = torch.device("cpu")
print(f"当前使用的设备: {device}")

# Replace with the local model directory.
LOCAL_MODEL_PATH = "/home/zhoushasha/models/microsoft_beit_base_patch16_224_pt22k_ft22k"
# Replace with the Caltech-256 dataset root folder.
CALTECH_256_PATH = "/home/zhoushasha/models/256ObjectCategoriesNew"

# File extensions (compared case-insensitively) that are treated as images.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


class COCOImageClassifier:
    """Top-k image classifier backed by a locally stored transformers model."""

    def __init__(self, model_path: str, local_image_paths: list):
        """Load the processor and model and remember the images to classify.

        Args:
            model_path: Location passed to both ``from_pretrained`` calls.
            local_image_paths: Image file paths that ``batch_predict`` iterates over.
        """
        self.processor = AutoImageProcessor.from_pretrained(model_path)
        self.model = AutoModelForImageClassification.from_pretrained(model_path)
        # Move the model to the CPU device.
        self.model = self.model.to(device)
        self.id2label = self.model.config.id2label
        self.local_image_paths = local_image_paths

    def predict_image_path(self, image_path: str, top_k: int = 5) -> Optional[dict]:
        """Predict the most likely classes for one local image file.

        Args:
            image_path: Path of the image file to classify.
            top_k: Number of highest-confidence classes to return. Values larger
                than the number of model classes are clamped to that number.

        Returns:
            ``{"image_path": ..., "predictions": [...]}`` where each prediction
            holds ``class_id``, ``class_name`` and ``confidence``; ``None`` if
            anything goes wrong (the error is printed, not raised).
        """
        try:
            # The context manager releases the file handle; convert() yields a
            # separate image object that stays usable after the file is closed.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")

            # Preprocess and move the tensors to the CPU device.
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = inputs.to(device)

            # Inference without gradient tracking.
            with torch.no_grad():
                outputs = self.model(**inputs)

            probs = torch.nn.functional.softmax(outputs.logits, dim=1)
            # topk raises when k exceeds the class dimension; clamp instead of
            # letting the generic handler below turn that into a silent None.
            k = min(top_k, probs.shape[1])
            top_probs, top_indices = probs.topk(k, dim=1)

            predictions = [
                {
                    "class_id": class_idx,
                    "class_name": self.id2label[class_idx],
                    "confidence": confidence,
                }
                for class_idx, confidence in zip(
                    top_indices[0].tolist(), top_probs[0].tolist()
                )
            ]
            return {
                "image_path": image_path,
                "predictions": predictions,
            }
        except Exception as e:
            # Deliberate best-effort: one unreadable image must not abort a batch.
            print(f"处理图片文件 {image_path} 时出错: {e}")
            return None

    def batch_predict(self, limit: int = 20, top_k: int = 5) -> list:
        """Classify the first ``limit`` stored images and report throughput.

        Args:
            limit: Maximum number of images to process.
            top_k: Number of highest-confidence classes to keep per image.

        Returns:
            The successful per-image results; failed images are omitted.
        """
        results = []
        local_image_paths = self.local_image_paths[:limit]

        print(f"开始预测 {len(local_image_paths)} 张本地图片...")
        # perf_counter is monotonic, so the interval cannot go negative.
        start_time = time.perf_counter()
        for image_path in tqdm(local_image_paths):
            result = self.predict_image_path(image_path, top_k)
            if result is not None:
                results.append(result)
        elapsed = time.perf_counter() - start_time

        # Throughput; guard against a zero interval (e.g. nothing to process).
        throughput = len(results) / elapsed if elapsed > 0 else 0.0
        print(f"模型每秒可以处理 {throughput:.2f} 张图片")

        return results

    def save_results(self, results: list, output_file: str = "celtech_cpu_predictions.json"):
        """Write the prediction results to a JSON file.

        Args:
            results: List of prediction results.
            output_file: Name of the output file.
        """
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        print(f"结果已保存到 {output_file}")


def _collect_samples(dataset_root: str, per_class: int = 3):
    """Return ``(image_paths, true_labels)`` with up to ``per_class`` random images per class folder."""
    image_paths = []
    true_labels = {}
    for folder in sorted(os.listdir(dataset_root)):
        folder_path = os.path.join(dataset_root, folder)
        if not os.path.isdir(folder_path):
            continue
        # The class name is the part after the first dot of the folder name;
        # fall back to the whole name when there is no (or an empty) suffix,
        # because an empty label would match every prediction below.
        class_name = folder.partition('.')[2] or folder
        image_files = [
            os.path.join(folder_path, f)
            for f in sorted(os.listdir(folder_path))
            if f.lower().endswith(IMAGE_EXTENSIONS)
        ]
        for image_path in random.sample(image_files, min(per_class, len(image_files))):
            image_paths.append(image_path)
            true_labels[image_path] = class_name
    return image_paths, true_labels


def _evaluate_top1(results: list, true_labels: dict):
    """Return ``(accuracy, recall)`` of the top-1 predictions; ``(0.0, 0.0)`` for no results."""
    if not results:
        return 0.0, 0.0

    correct_count = 0
    class_true_positives = {}
    class_false_negatives = {}
    for prediction in results:
        true_class = true_labels[prediction['image_path']].lower()
        class_true_positives.setdefault(true_class, 0)
        class_false_negatives.setdefault(true_class, 0)

        candidates = prediction['predictions']
        predicted_class = ''
        if candidates:
            top1_prediction = max(candidates, key=lambda x: x['confidence'])
            predicted_class = top1_prediction['class_name'].lower()

        # Heuristic match: the true label must be a substring of at least one
        # whitespace-separated word of the predicted label. Each image counts
        # exactly once, either as a hit or as a miss.
        if any(true_class in word for word in predicted_class.split()):
            correct_count += 1
            class_true_positives[true_class] += 1
        else:
            class_false_negatives[true_class] += 1

    accuracy = correct_count / len(results)

    # Recall pooled over all classes. Every image adds one TP or one FN, so the
    # denominator equals len(results) and is non-zero here.
    total_true_positives = sum(class_true_positives.values())
    total_false_negatives = sum(class_false_negatives.values())
    recall = total_true_positives / (total_true_positives + total_false_negatives)
    return accuracy, recall


def main() -> None:
    """Sample images, run batch prediction, save the results and print metrics."""
    # Randomly pick 3 images from every class folder of the dataset.
    local_image_paths, true_labels = _collect_samples(CALTECH_256_PATH, per_class=3)

    # Create the classifier instance.
    classifier = COCOImageClassifier(LOCAL_MODEL_PATH, local_image_paths)

    # Batch prediction.
    results = classifier.batch_predict(limit=len(local_image_paths), top_k=3)

    # Save the results.
    classifier.save_results(results)

    # Print a brief summary.
    print(f"\n处理完成: 成功预测 {len(results)} 张图片")
    if not results:
        # Nothing was predicted, so there is no sample and no metric to report.
        return

    print("\n示例预测结果:")
    sample = results[0]
    print(f"图片路径: {sample['image_path']}")
    for i, pred in enumerate(sample['predictions'], 1):
        print(f"{i}. {pred['class_name']} (置信度: {pred['confidence']:.2%})")

    accuracy, recall = _evaluate_top1(results, true_labels)
    print(f"\nAccuracy: {accuracy * 100:.2f}%")
    print(f"Recall: {recall * 100:.2f}%")


# Main program
if __name__ == "__main__":
    main()