Files
enginex-bi_series-vc-cnn/model_test_caltech_http_1.py
zhousha 55a67e817e update
2025-08-06 15:38:55 +08:00

163 lines
6.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import requests
import json
import torch
from PIL import Image
from io import BytesIO
from transformers import AutoImageProcessor, AutoModelForImageClassification
from tqdm import tqdm
import os
import random
import time
from flask import Flask, request, jsonify # 引入Flask
# Choose the torch device once, at import time: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"当前使用的设备: {device}")
class COCOImageClassifier:
    """Image classifier backed by a Hugging Face image-classification checkpoint.

    Class names come from the checkpoint's ``config.id2label`` mapping. Image
    paths are not bound at construction time; they are passed to the
    prediction methods.
    """

    def __init__(self, model_path: str):
        """Load the processor and model from *model_path* and move the model to ``device``.

        Args:
            model_path: Directory (or identifier) accepted by ``from_pretrained``.
        """
        self.processor = AutoImageProcessor.from_pretrained(model_path)
        self.model = AutoModelForImageClassification.from_pretrained(model_path)
        self.model = self.model.to(device)
        # Inference only: make sure dropout / norm layers run in evaluation mode.
        self.model.eval()
        if torch.cuda.device_count() > 1:
            print(f"使用 {torch.cuda.device_count()} 块GPU")
            self.model = torch.nn.DataParallel(self.model)
        # DataParallel wraps the model, so its config lives on `.module`.
        self.id2label = self.model.module.config.id2label if hasattr(self.model, 'module') else self.model.config.id2label

    def predict_image_path(self, image_path: str, top_k: int = 5) -> dict:
        """Classify a single image file and return its top-k predictions.

        Args:
            image_path: Path of the image to classify; it is converted to RGB.
            top_k: Number of predictions to return. Values larger than the
                number of model classes are clamped instead of failing.

        Returns:
            ``{"image_path": ..., "predictions": [...]}`` where each prediction
            has ``class_id``, ``class_name`` and ``confidence`` (softmax
            probability), ordered by descending confidence; or ``None`` if the
            image could not be processed (the error is printed).
        """
        try:
            # The context manager closes the file handle; convert() forces the load.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            inputs = self.processor(images=image, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = self.model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=1)
            # topk raises if k exceeds the class dimension; clamp it.
            k = min(top_k, probs.size(1))
            top_probs, top_indices = probs.topk(k, dim=1)
            predictions = []
            for confidence, class_idx in zip(top_probs[0].tolist(), top_indices[0].tolist()):
                predictions.append({
                    "class_id": class_idx,
                    "class_name": self.id2label[class_idx],
                    "confidence": confidence,
                })
            return {
                "image_path": image_path,
                "predictions": predictions,
            }
        except Exception as e:
            # Best-effort per image: one bad file must not abort a whole batch.
            print(f"处理图片 {image_path} 出错: {e}")
            return None

    def batch_predict_and_evaluate(self, image_paths: list, true_labels: dict, top_k: int = 3) -> dict:
        """Predict every image in *image_paths* sequentially and score the results.

        Args:
            image_paths: Image files to classify.
            true_labels: Maps each image path to its ground-truth class name.
            top_k: Predictions kept per image (only the top-1 is scored).

        Returns:
            ``{"status": "success", "metrics": {...}, "sample_predictions": [...]}``
            where ``metrics`` holds accuracy and average recall (percentages),
            image counts and throughput, and ``sample_predictions`` holds up to
            three raw results. Images that fail to predict are excluded from
            every metric.
        """
        results = []
        start_time = time.perf_counter()
        for image_path in tqdm(image_paths):
            result = self.predict_image_path(image_path, top_k)
            if result:
                results.append(result)
        total_time = time.perf_counter() - start_time
        images_per_second = len(results) / total_time if total_time > 0 else 0

        metrics = self._compute_metrics(results, true_labels)
        metrics["speed_images_per_second"] = round(images_per_second, 2)
        return {
            "status": "success",
            "metrics": metrics,
            "sample_predictions": results[:3],  # a few example predictions
        }

    @staticmethod
    def _compute_metrics(results: list, true_labels: dict) -> dict:
        """Score top-1 predictions: overall accuracy and macro-averaged recall, in percent.

        A prediction is correct when the lower-cased true class name is a
        substring of any whitespace-separated word of the lower-cased top-1
        class name. Images with no known true label, or with no predictions,
        always count as incorrect.
        """
        from collections import Counter

        correct_count = 0
        class_actual_count = Counter()
        class_correct_count = Counter()
        for prediction in results:
            true_class = true_labels.get(prediction['image_path'], "").lower()
            class_actual_count[true_class] += 1
            candidates = prediction['predictions']
            if not true_class or not candidates:
                # "" is a substring of every word, so an unknown label would
                # otherwise always be scored as correct.
                continue
            top1 = max(candidates, key=lambda x: x['confidence'])
            words = top1['class_name'].lower().split()
            if any(true_class in word for word in words):
                correct_count += 1
                class_correct_count[true_class] += 1

        total_count = len(results)
        accuracy = correct_count / total_count if total_count > 0 else 0
        recall_per_class = {
            name: class_correct_count[name] / count
            for name, count in class_actual_count.items()
        }
        average_recall = sum(recall_per_class.values()) / len(recall_per_class) if recall_per_class else 0
        return {
            "accuracy": round(accuracy * 100, 2),
            "average_recall": round(average_recall * 100, 2),
            "total_images": total_count,
            "correct_predictions": correct_count,
        }
# Flask application and container-level configuration.
app = Flask(__name__)
# Model checkpoint location inside the container; override via the MODEL_PATH env var.
MODEL_PATH = os.getenv("MODEL_PATH", "/model")
# Dataset root inside the container; override via the DATASET_PATH env var.
DATASET_PATH = os.getenv("DATASET_PATH", "/app/dataset")
# Loaded once at import time and shared by every request handler.
classifier = COCOImageClassifier(MODEL_PATH)
def _collect_dataset_images(dataset_path: str, per_class: int = 3):
    """Randomly sample up to *per_class* images from each class sub-folder.

    Sub-folders are expected to be named ``<prefix>.<class_name>`` (presumably
    the Caltech ``NNN.name`` layout — confirm); a folder without a ``.`` uses
    its full name as the class. Folders and files are visited in sorted order
    so later truncation is reproducible across filesystems.

    Returns:
        ``(image_paths, true_labels)`` where ``true_labels`` maps each sampled
        path to its class name.
    """
    image_paths = []
    true_labels = {}
    for folder in sorted(os.listdir(dataset_path)):
        folder_path = os.path.join(dataset_path, folder)
        if not os.path.isdir(folder_path):
            continue
        _, sep, suffix = folder.partition('.')
        class_name = suffix if sep else folder
        image_files = [
            os.path.join(folder_path, f)
            for f in sorted(os.listdir(folder_path))
            if f.lower().endswith(('.jpg', '.jpeg', '.png'))
        ]
        for image_path in random.sample(image_files, min(per_class, len(image_files))):
            image_paths.append(image_path)
            true_labels[image_path] = class_name
    return image_paths, true_labels


@app.route('/v1/private/s782b4996', methods=['POST'])
def evaluate():
    """Sample the dataset, classify the sample and return evaluation metrics.

    Optional JSON body fields:
        limit: maximum number of sampled images to evaluate (default 20;
            ``null`` means no limit).
        per_class: images sampled from each class folder (default 3).

    Returns the ``batch_predict_and_evaluate`` result as JSON, a 400 error for
    invalid parameters, or a 500 error for any other failure.
    """
    try:
        # silent=True: a missing or non-JSON body means "use defaults" instead
        # of raising; `or {}` also covers a JSON `null` body.
        data = request.get_json(silent=True) or {}
        if not isinstance(data, dict):
            return jsonify({
                "status": "error",
                "message": "request body must be a JSON object"
            }), 400
        try:
            limit = data.get("limit", 20)  # cap on the number of images processed
            if limit is not None:
                limit = int(limit)
            per_class = int(data.get("per_class", 3))
        except (TypeError, ValueError):
            return jsonify({
                "status": "error",
                "message": "limit and per_class must be integers"
            }), 400
        if (limit is not None and limit < 0) or per_class < 0:
            # A negative slice bound would silently drop images from the end.
            return jsonify({
                "status": "error",
                "message": "limit and per_class must be non-negative"
            }), 400

        local_image_paths, true_labels = _collect_dataset_images(DATASET_PATH, per_class)
        # Slicing with None keeps every sampled image.
        local_image_paths = local_image_paths[:limit]
        result = classifier.batch_predict_and_evaluate(local_image_paths, true_labels, top_k=3)
        return jsonify(result)
    except Exception as e:
        return jsonify({
            "status": "error",
            "message": str(e)
        }), 500
@app.route('/health', methods=['GET'])
def health_check():
    """Report that the service is up, together with the torch device in use."""
    payload = {"status": "healthy", "device": str(device)}
    status_code = 200
    return jsonify(payload), status_code
if __name__ == "__main__":
    # Run Flask's built-in server on every interface, port 8000, debug mode off.
    app.run(debug=False, port=8000, host='0.0.0.0')