# Source: enginex-bi_series-vc-cnn/model_test_caltech_3.py
# Last commit: zhousha 55a67e817e "update", 2025-08-06 15:38:55 +08:00
# 215 lines, 7.9 KiB, Python
#
# Note: the repository viewer flagged this file as containing ambiguous
# Unicode characters that might be confused with other characters. If that is
# intentional, the warning can be safely ignored.

import requests
import json
import torch
from PIL import Image
from io import BytesIO
from transformers import BeitImageProcessor, BeitForImageClassification
# 根据模型实际架构选择类
from transformers import ViTForImageClassification, BeitForImageClassification
from tqdm import tqdm
from transformers import AutoConfig
from transformers import AutoImageProcessor, AutoModelForImageClassification
import os
import random
import time # 新增导入时间模块
# Use GPU acceleration (the original note mentions Iluvatar GPUs) through the
# CUDA backend when it is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"当前使用的设备: {device}")  # debug output: which device was selected
# With two or more GPUs visible, DataParallel can be used for parallel
# computation; here we only report how many GPUs were found.
if torch.cuda.device_count() >= 2:
    print(f"使用 {torch.cuda.device_count()} 块 GPU 进行计算")
class COCOImageClassifier:
    """Image classifier built on a Hugging Face image-classification checkpoint.

    The processor and model are loaded with ``from_pretrained``; the model is
    moved to the module-level ``device`` and wrapped in
    ``torch.nn.DataParallel`` when more than one CUDA device is visible.
    """

    def __init__(self, model_path: str, local_image_paths: list):
        """Load the processor and model and remember the images to classify.

        Args:
            model_path: Path or identifier handed to ``from_pretrained``.
            local_image_paths: Local image file paths used by
                ``batch_predict``.
        """
        self.processor = AutoImageProcessor.from_pretrained(model_path)
        self.model = AutoModelForImageClassification.from_pretrained(model_path)
        # Move the model to the selected device.
        self.model = self.model.to(device)
        # Debug output: confirm where the parameters actually live.
        print(f"模型是否在 GPU 上: {next(self.model.parameters()).is_cuda}")
        # Read the label map before wrapping: DataParallel hides the
        # underlying model (and its config) behind ``.module``.
        self.id2label = self.model.config.id2label
        if torch.cuda.device_count() > 1:
            self.model = torch.nn.DataParallel(self.model)
        self.local_image_paths = local_image_paths

    @staticmethod
    def _top_predictions(logits, id2label: dict, top_k: int) -> list:
        """Turn a ``(1, num_classes)`` logits tensor into ranked predictions.

        ``top_k`` is capped at the number of classes so that asking for more
        entries than the model has labels returns every class instead of
        making ``topk`` raise.
        """
        probs = torch.nn.functional.softmax(logits, dim=1)
        k = min(top_k, probs.shape[1])
        top_probs, top_indices = probs.topk(k, dim=1)
        return [
            {
                "class_id": class_idx,
                "class_name": id2label[class_idx],
                "confidence": confidence,
            }
            for confidence, class_idx in zip(
                top_probs[0].tolist(), top_indices[0].tolist()
            )
        ]

    @staticmethod
    def _images_per_second(image_count: int, elapsed_seconds: float) -> float:
        """Return the throughput, or 0.0 when no measurable time elapsed."""
        if elapsed_seconds <= 0:
            return 0.0
        return image_count / elapsed_seconds

    def predict_image_path(self, image_path: str, top_k: int = 5) -> dict:
        """Predict the categories of one local image file.

        Args:
            image_path: Path of the local image file.
            top_k: Number of highest-confidence classes to return (capped at
                the number of classes the model knows).

        Returns:
            A dict with ``image_path`` and ``predictions`` (a list of
            ``class_id`` / ``class_name`` / ``confidence`` dicts ordered by
            decreasing confidence), or ``None`` when the image could not be
            processed; in that case the error is printed, not raised.
        """
        try:
            # The context manager releases the file handle as soon as the
            # RGB copy has been made.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")
            # Preprocess and move the inputs to the same device as the model.
            inputs = self.processor(images=image, return_tensors="pt").to(device)
            # Inference only: no gradients needed.
            with torch.no_grad():
                outputs = self.model(**inputs)
            return {
                "image_path": image_path,
                "predictions": self._top_predictions(
                    outputs.logits, self.id2label, top_k
                ),
            }
        except Exception as e:
            # Deliberately best-effort: one unreadable image must not abort a
            # whole batch, so report the problem and signal failure with None.
            print(f"处理图片文件 {image_path} 时出错: {e}")
            return None

    def batch_predict(self, limit: int = 20, top_k: int = 5) -> list:
        """Predict the first ``limit`` local images and report throughput.

        Args:
            limit: Maximum number of images to process.
            top_k: Number of highest-confidence classes to return per image.

        Returns:
            The list of successful prediction dicts; images that failed are
            omitted.
        """
        results = []
        local_image_paths = self.local_image_paths[:limit]
        print(f"开始预测 {len(local_image_paths)} 张本地图片...")
        # perf_counter is monotonic and high resolution, unlike time.time().
        start_time = time.perf_counter()
        for image_path in tqdm(local_image_paths):
            result = self.predict_image_path(image_path, top_k)
            if result is not None:
                results.append(result)
        total_time = time.perf_counter() - start_time
        # Only successful predictions are counted, while the elapsed time also
        # covers failed images. Guarded against a zero-length interval.
        images_per_second = self._images_per_second(len(results), total_time)
        print(f"模型每秒可以处理 {images_per_second:.2f} 张图片")
        return results

    def save_results(self, results: list, output_file: str = "caltech_predictions.json"):
        """Save prediction results to a JSON file.

        Args:
            results: List of prediction results.
            output_file: Name of the output file.
        """
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        print(f"结果已保存到 {output_file}")
# Main program
# Replace with the local model path.
LOCAL_MODEL_PATH = "/home/zhoushasha/models/microsoft_beit_base_patch16_224_pt22k_ft22k"
# Replace with the path of the Caltech 256 dataset folder.
CALTECH_256_PATH = "/home/zhoushasha/models/256ObjectCategoriesNew"


def collect_caltech_samples(dataset_root: str, per_class: int = 3):
    """Randomly pick up to ``per_class`` images from every class folder.

    Each sub-directory of ``dataset_root`` is treated as one class. The class
    name is the part of the folder name after the first ``.``; a folder name
    without a dot is used as the class name unchanged instead of raising.
    Image files are matched by extension, ignoring case.

    Returns:
        ``(image_paths, true_labels)`` where ``true_labels`` maps every
        selected image path to its class name.
    """
    image_paths = []
    true_labels = {}
    # Sorted so the folder order does not depend on the file system.
    for folder in sorted(os.listdir(dataset_root)):
        folder_path = os.path.join(dataset_root, folder)
        if not os.path.isdir(folder_path):
            continue
        _, dot, suffix = folder.partition('.')
        class_name = suffix if dot else folder
        image_files = [
            os.path.join(folder_path, f)
            for f in sorted(os.listdir(folder_path))
            if f.lower().endswith(('.jpg', '.jpeg', '.png'))
        ]
        # Never ask random.sample for more items than the folder holds.
        selected_images = random.sample(image_files, min(per_class, len(image_files)))
        for image_path in selected_images:
            image_paths.append(image_path)
            true_labels[image_path] = class_name
    return image_paths, true_labels


def prediction_matches(predicted_class: str, true_class: str) -> bool:
    """Return True when any whitespace-separated word contains the true label.

    The comparison is case-insensitive. It is a substring heuristic: a label
    that itself contains whitespace can never match a single word, and a
    short label also matches longer words that merely contain it.
    """
    wanted = true_class.lower()
    return any(wanted in word for word in predicted_class.lower().split())


def evaluate_predictions(results: list, true_labels: dict) -> tuple:
    """Compute top-1 accuracy and the average per-class recall.

    Args:
        results: Prediction dicts with ``image_path`` and ``predictions``.
        true_labels: Mapping from image path to true class name.

    Returns:
        ``(accuracy, average_recall)`` as fractions in ``[0, 1]``; both are
        0.0 when ``results`` is empty.
    """
    if not results:
        return 0.0, 0.0
    correct_count = 0
    # Per-class number of samples seen and number predicted correctly.
    class_actual_count = {}
    class_correct_count = {}
    for prediction in results:
        candidates = prediction['predictions']
        # An image without any candidate class simply counts as a miss.
        if candidates:
            top1 = max(candidates, key=lambda x: x['confidence'])
            predicted_class = top1['class_name']
        else:
            predicted_class = ""
        true_class = true_labels[prediction['image_path']].lower()
        class_actual_count[true_class] = class_actual_count.get(true_class, 0) + 1
        if prediction_matches(predicted_class, true_class):
            correct_count += 1
            class_correct_count[true_class] = class_correct_count.get(true_class, 0) + 1
    accuracy = correct_count / len(results)
    recall_per_class = {
        class_name: class_correct_count.get(class_name, 0) / actual
        for class_name, actual in class_actual_count.items()
    }
    average_recall = sum(recall_per_class.values()) / len(recall_per_class)
    return accuracy, average_recall


def main() -> None:
    """Sample Caltech 256 images, classify them, then save and score the results."""
    local_image_paths, true_labels = collect_caltech_samples(CALTECH_256_PATH, per_class=3)
    # Create the classifier instance.
    classifier = COCOImageClassifier(LOCAL_MODEL_PATH, local_image_paths)
    # Batch prediction.
    results = classifier.batch_predict(limit=len(local_image_paths), top_k=3)
    # Save the results.
    classifier.save_results(results)
    # Print a brief summary.
    print(f"\n处理完成: 成功预测 {len(results)} 张图片")
    if not results:
        # Nothing was predicted, so there is no sample to show and no metric
        # to compute (the ratios would divide by zero).
        return
    print("\n示例预测结果:")
    sample = results[0]
    print(f"图片路径: {sample['image_path']}")
    for i, pred in enumerate(sample['predictions'], 1):
        print(f"{i}. {pred['class_name']} (置信度: {pred['confidence']:.2%})")
    accuracy, average_recall = evaluate_predictions(results, true_labels)
    print(f"\nAccuracy: {accuracy * 100:.2f}%")
    print(f"\nAverage Recall: {average_recall * 100:.2f}%")


if __name__ == "__main__":
    main()