Files
enginex-bi_series-vc-cnn/model_test_caltech_cpu1.py
zhousha 55a67e817e update
2025-08-06 15:38:55 +08:00

197 lines
6.6 KiB
Python

import requests
import json
import torch
from PIL import Image
from io import BytesIO
from transformers import AutoImageProcessor, AutoModelForImageClassification
from tqdm import tqdm
import os
import random
import time
# Force inference onto the CPU: the classifier below moves both the model and
# its input tensors to this device.
device = torch.device("cpu")
print(f"当前使用的设备: {device}")
class COCOImageClassifier:
    """Classify local image files with a pretrained image-classification model.

    The processor and model are loaded from a local checkpoint directory via
    ``AutoImageProcessor`` / ``AutoModelForImageClassification`` and run on the
    module-level ``device``.
    """

    def __init__(self, model_path: str, local_image_paths: list):
        """Load the checkpoint and remember which images to classify.

        Args:
            model_path: Path handed to ``from_pretrained`` for both the
                processor and the model.
            local_image_paths: Image file paths used by ``batch_predict``.
        """
        self.processor = AutoImageProcessor.from_pretrained(model_path)
        self.model = AutoModelForImageClassification.from_pretrained(model_path)
        # Move the model to the module-level device (CPU).
        self.model = self.model.to(device)
        self.id2label = self.model.config.id2label
        self.local_image_paths = local_image_paths

    def predict_image_path(self, image_path: str, top_k: int = 5) -> dict:
        """Predict the most likely classes for one local image file.

        Args:
            image_path: Path of the image file to classify.
            top_k: Number of highest-confidence classes to return. Values
                larger than the number of model outputs are clamped.

        Returns:
            A dict with ``image_path`` and ``predictions`` (a list of
            ``class_id`` / ``class_name`` / ``confidence`` dicts), or ``None``
            when the image could not be processed.
        """
        # Deliberately best-effort: one unreadable or broken image must not
        # abort a whole batch, so any failure is reported and mapped to None.
        try:
            # The context manager releases the file handle as soon as the
            # pixel data has been converted.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")

            inputs = self.processor(images=image, return_tensors="pt")
            inputs = inputs.to(device)

            with torch.no_grad():
                outputs = self.model(**inputs)

            probs = torch.nn.functional.softmax(outputs.logits, dim=1)
            # topk raises when k exceeds the size of the class dimension, which
            # would otherwise turn every image into a silent failure.
            k = min(top_k, probs.shape[1])
            top_probs, top_indices = probs.topk(k, dim=1)

            predictions = [
                {
                    "class_id": class_idx,
                    "class_name": self.id2label[class_idx],
                    "confidence": confidence,
                }
                for class_idx, confidence in zip(
                    top_indices[0].tolist(), top_probs[0].tolist()
                )
            ]
            return {
                "image_path": image_path,
                "predictions": predictions,
            }
        except Exception as e:
            print(f"处理图片文件 {image_path} 时出错: {e}")
            return None

    def batch_predict(self, limit: int = 20, top_k: int = 5) -> list:
        """Classify the first ``limit`` stored images and report throughput.

        Args:
            limit: Maximum number of images taken from ``local_image_paths``.
            top_k: Number of highest-confidence classes kept per image.

        Returns:
            The prediction dicts of all images that were processed
            successfully; failed images are left out.
        """
        results = []
        local_image_paths = self.local_image_paths[:limit]
        print(f"开始预测 {len(local_image_paths)} 张本地图片...")

        # perf_counter is monotonic and high-resolution, unlike wall-clock time.
        start_time = time.perf_counter()
        for image_path in tqdm(local_image_paths):
            result = self.predict_image_path(image_path, top_k)
            if result is not None:
                results.append(result)
        elapsed = time.perf_counter() - start_time

        # Throughput counts successful predictions only. Guard the division so
        # an empty or instantaneous run cannot raise ZeroDivisionError.
        throughput = len(results) / elapsed if elapsed > 0 else 0.0
        print(f"模型每秒可以处理 {throughput:.2f} 张图片")
        return results

    def save_results(self, results: list, output_file: str = "celtech_cpu_predictions.json"):
        """Write the prediction results to a UTF-8 JSON file.

        Args:
            results: List of prediction dicts to serialize.
            output_file: Destination file name.
        """
        with open(output_file, "w", encoding="utf-8") as f:
            # ensure_ascii=False keeps non-ASCII paths and labels readable.
            json.dump(results, f, ensure_ascii=False, indent=2)
        print(f"结果已保存到 {output_file}")
def collect_caltech_samples(dataset_root, samples_per_class=3):
    """Randomly pick a few images from every category folder of a dataset.

    Args:
        dataset_root: Directory whose sub-directories each hold the images of
            one category. Folder names are expected to look like
            ``<index>.<category>``.
        samples_per_class: Maximum number of images drawn from each folder.

    Returns:
        A ``(image_paths, true_labels)`` tuple: the list of selected image
        paths and a dict mapping each selected path to its category name.
    """
    image_extensions = (".jpg", ".jpeg", ".png")
    image_paths = []
    true_labels = {}
    # os.listdir order is arbitrary; sort so folders are visited predictably.
    for folder in sorted(os.listdir(dataset_root)):
        folder_path = os.path.join(dataset_root, folder)
        if not os.path.isdir(folder_path):
            continue
        # The category is the text after the first dot. Fall back to the whole
        # folder name when there is no dot (or nothing after it) instead of
        # raising IndexError or producing an empty label.
        class_name = folder.partition(".")[2] or folder
        # Compare extensions case-insensitively so ".JPG" files are kept.
        image_files = [
            os.path.join(folder_path, file_name)
            for file_name in os.listdir(folder_path)
            if file_name.lower().endswith(image_extensions)
        ]
        sample_size = min(samples_per_class, len(image_files))
        for image_path in random.sample(image_files, sample_size):
            image_paths.append(image_path)
            true_labels[image_path] = class_name
    return image_paths, true_labels


def score_top1_predictions(results, true_labels):
    """Count top-1 hits and misses per true class.

    A prediction is a hit when the lower-cased true label is a substring of
    any whitespace-separated word of the lower-cased top-1 class name.

    Args:
        results: Prediction dicts with ``image_path`` and ``predictions``.
        true_labels: Mapping from image path to its true class name.

    Returns:
        ``(correct_count, class_true_positives, class_false_negatives)`` where
        the two dicts are keyed by lower-cased true class name.

    Raises:
        KeyError: If a result's image path has no entry in ``true_labels``.
    """
    correct_count = 0
    class_true_positives = {}
    class_false_negatives = {}
    for prediction in results:
        true_class = true_labels[prediction["image_path"]].lower()
        class_true_positives.setdefault(true_class, 0)
        class_false_negatives.setdefault(true_class, 0)

        # An empty prediction list simply counts as a miss.
        top1_prediction = max(
            prediction["predictions"],
            key=lambda candidate: candidate["confidence"],
            default=None,
        )
        predicted_words = (
            top1_prediction["class_name"].lower().split() if top1_prediction else []
        )
        if any(true_class in word for word in predicted_words):
            correct_count += 1
            class_true_positives[true_class] += 1
        else:
            class_false_negatives[true_class] += 1
    return correct_count, class_true_positives, class_false_negatives


# Main program
if __name__ == "__main__":
    # Replace with the local model path.
    LOCAL_MODEL_PATH = "/home/zhoushasha/models/microsoft_beit_base_patch16_224_pt22k_ft22k"
    # Replace with the Caltech 256 dataset folder path.
    CALTECH_256_PATH = "/home/zhoushasha/models/256ObjectCategoriesNew"

    # Randomly select up to 3 images from every category folder.
    local_image_paths, true_labels = collect_caltech_samples(CALTECH_256_PATH, 3)

    # Create the classifier and run the batch prediction.
    classifier = COCOImageClassifier(LOCAL_MODEL_PATH, local_image_paths)
    results = classifier.batch_predict(limit=len(local_image_paths), top_k=3)
    classifier.save_results(results)

    # Brief summary.
    print(f"\n处理完成: 成功预测 {len(results)} 张图片")
    if results:
        print("\n示例预测结果:")
        sample = results[0]
        print(f"图片路径: {sample['image_path']}")
        for i, pred in enumerate(sample['predictions'], 1):
            print(f"{i}. {pred['class_name']} (置信度: {pred['confidence']:.2%})")

    correct_count, class_true_positives, class_false_negatives = score_top1_predictions(
        results, true_labels
    )
    total_count = len(results)
    if total_count == 0:
        # Nothing was predicted successfully; the ratios below are undefined.
        print("\nNo successful predictions; skipping accuracy and recall.")
    else:
        accuracy = correct_count / total_count
        print(f"\nAccuracy: {accuracy * 100:.2f}%")

        # Recall pooled over all classes. Every image contributes exactly one
        # hit or one miss, so this equals the top-1 accuracy above.
        total_true_positives = sum(class_true_positives.values())
        total_false_negatives = sum(class_false_negatives.values())
        recall = total_true_positives / (total_true_positives + total_false_negatives)
        print(f"Recall: {recall * 100:.2f}%")