Files
enginex-bi_series-vc-cnn/model_test_caltech_3.py

215 lines
7.9 KiB
Python
Raw Normal View History

2025-08-06 15:38:55 +08:00
import requests
import json
import torch
from PIL import Image
from io import BytesIO
from transformers import BeitImageProcessor, BeitForImageClassification
# 根据模型实际架构选择类
from transformers import ViTForImageClassification, BeitForImageClassification
from tqdm import tqdm
from transformers import AutoConfig
from transformers import AutoImageProcessor, AutoModelForImageClassification
import os
import random
import time # 新增导入时间模块
# 支持 Iluvatar GPU 加速,若不可用则使用 CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"当前使用的设备: {device}") # 添加调试信息
# 若有多块 GPU可使用 DataParallel 进行并行计算
if torch.cuda.device_count() > 1:
print(f"使用 {torch.cuda.device_count()} 块 GPU 进行计算")
class COCOImageClassifier:
def __init__(self, model_path: str, local_image_paths: list):
"""初始化COCO图像分类器"""
self.processor = AutoImageProcessor.from_pretrained(model_path)
self.model = AutoModelForImageClassification.from_pretrained(model_path)
# 将模型移动到设备
self.model = self.model.to(device)
print(f"模型是否在 GPU 上: {next(self.model.parameters()).is_cuda}") # 添加调试信息
# 若有多块 GPU使用 DataParallel
if torch.cuda.device_count() > 1:
self.model = torch.nn.DataParallel(self.model)
self.id2label = self.model.module.config.id2label if hasattr(self.model, 'module') else self.model.config.id2label
self.local_image_paths = local_image_paths
def predict_image_path(self, image_path: str, top_k: int = 5) -> dict:
"""
预测本地图片文件对应的图片类别
Args:
image_path: 本地图片文件路径
top_k: 返回置信度最高的前k个类别
Returns:
包含预测结果的字典
"""
try:
# 打开图片
image = Image.open(image_path).convert("RGB")
# 预处理
inputs = self.processor(images=image, return_tensors="pt")
# 将输入数据移动到设备
inputs = inputs.to(device)
# 模型推理
with torch.no_grad():
outputs = self.model(**inputs)
# 获取预测结果
logits = outputs.logits
probs = torch.nn.functional.softmax(logits, dim=1)
top_probs, top_indices = probs.topk(top_k, dim=1)
# 整理结果
predictions = []
for i in range(top_k):
class_idx = top_indices[0, i].item()
confidence = top_probs[0, i].item()
predictions.append({
"class_id": class_idx,
"class_name": self.id2label[class_idx],
"confidence": confidence
})
return {
"image_path": image_path,
"predictions": predictions
}
except Exception as e:
print(f"处理图片文件 {image_path} 时出错: {e}")
return None
def batch_predict(self, limit: int = 20, top_k: int = 5) -> list:
"""
批量预测本地图片
Args:
limit: 限制处理的图片数量
top_k: 返回置信度最高的前k个类别
Returns:
包含所有预测结果的列表
"""
results = []
local_image_paths = self.local_image_paths[:limit]
print(f"开始预测 {len(local_image_paths)} 张本地图片...")
start_time = time.time() # 记录开始时间
for image_path in tqdm(local_image_paths):
result = self.predict_image_path(image_path, top_k)
if result:
results.append(result)
end_time = time.time() # 记录结束时间
total_time = end_time - start_time # 计算总时间
images_per_second = len(results) / total_time # 计算每秒处理的图片数量
print(f"模型每秒可以处理 {images_per_second:.2f} 张图片")
return results
def save_results(self, results: list, output_file: str = "caltech_predictions.json"):
"""
保存预测结果到JSON文件
Args:
results: 预测结果列表
output_file: 输出文件名
"""
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(results, f, ensure_ascii=False, indent=2)
print(f"结果已保存到 {output_file}")
# 主程序
if __name__ == "__main__":
# 替换为本地模型路径
LOCAL_MODEL_PATH = "/home/zhoushasha/models/microsoft_beit_base_patch16_224_pt22k_ft22k"
# 替换为Caltech 256数据集文件夹路径
CALTECH_256_PATH = "/home/zhoushasha/models/256ObjectCategoriesNew"
local_image_paths = []
true_labels = {}
# 遍历Caltech 256数据集中的每个文件夹
for folder in os.listdir(CALTECH_256_PATH):
folder_path = os.path.join(CALTECH_256_PATH, folder)
if os.path.isdir(folder_path):
# 获取文件夹名称中的类别名称
class_name = folder.split('.', 1)[1]
# 获取文件夹中的所有图片文件
image_files = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.endswith(('.jpg', '.jpeg', '.png'))]
# 随机选择3张图片
selected_images = random.sample(image_files, min(3, len(image_files)))
for image_path in selected_images:
local_image_paths.append(image_path)
true_labels[image_path] = class_name
# 创建分类器实例
classifier = COCOImageClassifier(LOCAL_MODEL_PATH, local_image_paths)
# 批量预测
results = classifier.batch_predict(limit=len(local_image_paths), top_k=3)
# 保存结果
classifier.save_results(results)
# 打印简要统计
print(f"\n处理完成: 成功预测 {len(results)} 张图片")
if results:
print("\n示例预测结果:")
sample = results[0]
print(f"图片路径: {sample['image_path']}")
for i, pred in enumerate(sample['predictions'], 1):
print(f"{i}. {pred['class_name']} (置信度: {pred['confidence']:.2%})")
correct_count = 0
total_count = len(results)
# 统计每个类别的实际样本数和正确预测数
class_actual_count = {}
class_correct_count = {}
for prediction in results:
image_path = prediction['image_path']
top1_prediction = max(prediction['predictions'], key=lambda x: x['confidence'])
predicted_class = top1_prediction['class_name'].lower()
true_class = true_labels.get(image_path).lower()
# 统计每个类别的实际样本数
if true_class not in class_actual_count:
class_actual_count[true_class] = 0
class_actual_count[true_class] += 1
# 检查预测类别中的每个单词是否包含真实标签
words = predicted_class.split()
for word in words:
if true_class in word:
correct_count += 1
# 统计每个类别的正确预测数
if true_class not in class_correct_count:
class_correct_count[true_class] = 0
class_correct_count[true_class] += 1
break
accuracy = correct_count / total_count
print(f"\nAccuracy: {accuracy * 100:.2f}%")
# 计算每个类别的召回率
recall_per_class = {}
for class_name in class_actual_count:
if class_name in class_correct_count:
recall_per_class[class_name] = class_correct_count[class_name] / class_actual_count[class_name]
else:
recall_per_class[class_name] = 0
# 计算平均召回率
average_recall = sum(recall_per_class.values()) / len(recall_per_class)
print(f"\nAverage Recall: {average_recall * 100:.2f}%")