# 197 lines · 6.6 KiB · Python
import json
import os
import random
import time
from io import BytesIO
from typing import Optional

import requests
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Inference is pinned to the CPU unconditionally, regardless of GPU availability.
_CPU = "cpu"
device = torch.device(_CPU)
print("当前使用的设备: " + str(device))
class COCOImageClassifier:
    """Classify local image files with a Hugging Face image-classification model.

    Despite the name, nothing here is COCO-specific: labels come from whatever
    checkpoint ``model_path`` points to, via ``model.config.id2label``.
    """

    def __init__(self, model_path: str, local_image_paths: list):
        """Load the image processor and classification model.

        Args:
            model_path: Path (or identifier) accepted by ``from_pretrained``.
            local_image_paths: Image file paths later consumed by ``batch_predict``.
        """
        self.processor = AutoImageProcessor.from_pretrained(model_path)
        self.model = AutoModelForImageClassification.from_pretrained(model_path)

        # Move the model to the module-level device and make inference mode
        # explicit (disables dropout and other training-only behavior).
        self.model = self.model.to(device)
        self.model.eval()

        self.id2label = self.model.config.id2label
        self.local_image_paths = local_image_paths

    def _top_k_predictions(self, logits: torch.Tensor, top_k: int) -> list:
        """Convert classification logits into the top-k prediction dicts.

        Args:
            logits: 2-D tensor of shape (batch, num_labels); only the first row
                (the single image in the batch) is used.
            top_k: Requested number of classes; clamped to ``num_labels`` because
                ``Tensor.topk`` raises when k exceeds the dimension size.

        Returns:
            List of ``{"class_id", "class_name", "confidence"}`` dicts ordered
            from most to least confident.
        """
        probs = torch.nn.functional.softmax(logits, dim=1)
        k = min(top_k, probs.shape[1])
        top_probs, top_indices = probs.topk(k, dim=1)
        return [
            {
                "class_id": class_idx,
                "class_name": self.id2label[class_idx],
                "confidence": confidence,
            }
            for class_idx, confidence in zip(
                top_indices[0].tolist(), top_probs[0].tolist()
            )
        ]

    def predict_image_path(self, image_path: str, top_k: int = 5) -> Optional[dict]:
        """Predict the classes of one local image file.

        Args:
            image_path: Path of the image file on disk.
            top_k: Number of highest-confidence classes to return (clamped to
                the number of labels the model knows).

        Returns:
            ``{"image_path": image_path, "predictions": [...]}`` on success, or
            ``None`` if anything fails (the error is printed, not raised).
        """
        try:
            # The context manager releases the file handle deterministically;
            # convert("RGB") loads the pixels and yields a 3-channel image.
            with Image.open(image_path) as img:
                image = img.convert("RGB")

            inputs = self.processor(images=image, return_tensors="pt").to(device)

            with torch.no_grad():
                logits = self.model(**inputs).logits

            predictions = self._top_k_predictions(logits, top_k)
        except Exception as e:
            # Best effort: a single unreadable/corrupt file must not abort a batch.
            print(f"处理图片文件 {image_path} 时出错: {e}")
            return None

        return {"image_path": image_path, "predictions": predictions}

    def batch_predict(self, limit: int = 20, top_k: int = 5) -> list:
        """Predict the first ``limit`` images of ``self.local_image_paths``.

        Args:
            limit: Maximum number of images to process (taken from the front).
            top_k: Number of highest-confidence classes to keep per image.

        Returns:
            Prediction dicts for the images that were processed successfully;
            failed images are omitted.
        """
        paths = self.local_image_paths[:limit]
        print(f"开始预测 {len(paths)} 张本地图片...")

        results = []
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        for image_path in tqdm(paths):
            result = self.predict_image_path(image_path, top_k)
            if result is not None:
                results.append(result)
        elapsed = time.perf_counter() - start_time

        # Throughput counts successful predictions only; guard against a zero
        # elapsed time, which would otherwise raise ZeroDivisionError.
        throughput = len(results) / elapsed if elapsed > 0 else 0.0
        print(f"模型每秒可以处理 {throughput:.2f} 张图片")

        return results

    def save_results(self, results: list, output_file: str = "celtech_cpu_predictions.json"):
        """Write prediction results to a UTF-8 JSON file.

        Args:
            results: Prediction dicts, as returned by ``batch_predict``.
            output_file: Destination path; overwritten if it exists.
        """
        # ensure_ascii=False keeps non-ASCII label text readable in the file.
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        print(f"结果已保存到 {output_file}")
# 主程序
|
|
if __name__ == "__main__":
|
|
# 替换为本地模型路径
|
|
LOCAL_MODEL_PATH = "/home/zhoushasha/models/microsoft_beit_base_patch16_224_pt22k_ft22k"
|
|
|
|
# 替换为Caltech 256数据集文件夹路径 New
|
|
CALTECH_256_PATH = "/home/zhoushasha/models/256ObjectCategoriesNew"
|
|
|
|
local_image_paths = []
|
|
true_labels = {}
|
|
|
|
# 遍历Caltech 256数据集中的每个文件夹
|
|
for folder in os.listdir(CALTECH_256_PATH):
|
|
folder_path = os.path.join(CALTECH_256_PATH, folder)
|
|
if os.path.isdir(folder_path):
|
|
# 获取文件夹名称中的类别名称
|
|
class_name = folder.split('.', 1)[1]
|
|
# 获取文件夹中的所有图片文件
|
|
image_files = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.endswith(('.jpg', '.jpeg', '.png'))]
|
|
# 随机选择3张图片
|
|
selected_images = random.sample(image_files, min(3, len(image_files)))
|
|
for image_path in selected_images:
|
|
local_image_paths.append(image_path)
|
|
true_labels[image_path] = class_name
|
|
|
|
# 创建分类器实例
|
|
classifier = COCOImageClassifier(LOCAL_MODEL_PATH, local_image_paths)
|
|
|
|
# 批量预测
|
|
results = classifier.batch_predict(limit=len(local_image_paths), top_k=3)
|
|
|
|
# 保存结果
|
|
classifier.save_results(results)
|
|
|
|
# 打印简要统计
|
|
print(f"\n处理完成: 成功预测 {len(results)} 张图片")
|
|
if results:
|
|
print("\n示例预测结果:")
|
|
sample = results[0]
|
|
print(f"图片路径: {sample['image_path']}")
|
|
for i, pred in enumerate(sample['predictions'], 1):
|
|
print(f"{i}. {pred['class_name']} (置信度: {pred['confidence']:.2%})")
|
|
|
|
correct_count = 0
|
|
total_count = len(results)
|
|
class_true_positives = {}
|
|
class_false_negatives = {}
|
|
|
|
for prediction in results:
|
|
image_path = prediction['image_path']
|
|
top1_prediction = max(prediction['predictions'], key=lambda x: x['confidence'])
|
|
predicted_class = top1_prediction['class_name'].lower()
|
|
true_class = true_labels.get(image_path).lower()
|
|
|
|
if true_class not in class_true_positives:
|
|
class_true_positives[true_class] = 0
|
|
class_false_negatives[true_class] = 0
|
|
|
|
# 检查预测类别中的每个单词是否包含真实标签
|
|
words = predicted_class.split()
|
|
for word in words:
|
|
if true_class in word:
|
|
correct_count += 1
|
|
class_true_positives[true_class] += 1
|
|
break
|
|
else:
|
|
class_false_negatives[true_class] += 1
|
|
|
|
accuracy = correct_count / total_count
|
|
print(f"\nAccuracy: {accuracy * 100:.2f}%")
|
|
|
|
# 计算召回率
|
|
total_true_positives = 0
|
|
total_false_negatives = 0
|
|
for class_name in class_true_positives:
|
|
total_true_positives += class_true_positives[class_name]
|
|
total_false_negatives += class_false_negatives[class_name]
|
|
|
|
recall = total_true_positives / (total_true_positives + total_false_negatives)
|
|
print(f"Recall: {recall * 100:.2f}%") |