163 lines
6.4 KiB
Python
163 lines
6.4 KiB
Python
import requests
|
||
import json
|
||
import torch
|
||
from PIL import Image
|
||
from io import BytesIO
|
||
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
||
from tqdm import tqdm
|
||
import os
|
||
import random
|
||
import time
|
||
from flask import Flask, request, jsonify # 引入Flask
|
||
|
||
# Device selection: use CUDA when torch reports it as available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"当前使用的设备: {device}")
|
||
|
||
class COCOImageClassifier:
    """Wrapper around a Hugging Face image-classification model.

    Loads the processor and model from ``model_path``, moves the model to the
    module-level ``device`` and offers single-image prediction plus a batch
    helper that reports accuracy, average per-class recall and throughput.
    """

    def __init__(self, model_path: str):
        """Load the processor and model stored at ``model_path``.

        The model is moved to ``device`` and wrapped in ``DataParallel`` when
        more than one CUDA device is visible.
        """
        self.processor = AutoImageProcessor.from_pretrained(model_path)
        self.model = AutoModelForImageClassification.from_pretrained(model_path)
        self.model = self.model.to(device)

        if torch.cuda.device_count() > 1:
            print(f"使用 {torch.cuda.device_count()} 块GPU")
            self.model = torch.nn.DataParallel(self.model)

        # DataParallel exposes the wrapped model (and its config) via ``.module``.
        base_model = self.model.module if hasattr(self.model, 'module') else self.model
        self.id2label = base_model.config.id2label

    def predict_image_path(self, image_path: str, top_k: int = 5) -> dict:
        """Classify one image file and return its most likely classes.

        Args:
            image_path: Path of the image to classify.
            top_k: Number of predictions wanted. Values larger than the number
                of classes produced by the model are clamped to that number.

        Returns:
            ``{"image_path": ..., "predictions": [...]}`` where every
            prediction holds ``class_id``, ``class_name`` and ``confidence``.
            ``None`` when the image could not be processed; the error is
            printed and deliberately not re-raised so a batch can continue.
        """
        try:
            # Release the file handle deterministically; convert() returns a
            # loaded copy that stays valid after the file is closed.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")
            inputs = self.processor(images=image, return_tensors="pt").to(device)

            with torch.no_grad():
                outputs = self.model(**inputs)

            probs = torch.nn.functional.softmax(outputs.logits, dim=1)
            # topk() raises when k exceeds the class dimension, which would
            # otherwise make the image vanish from the results.
            k = min(top_k, probs.shape[1])
            top_probs, top_indices = probs.topk(k, dim=1)

            predictions = [
                {
                    "class_id": class_idx,
                    "class_name": self.id2label[class_idx],
                    "confidence": confidence,
                }
                for class_idx, confidence in zip(
                    top_indices[0].tolist(), top_probs[0].tolist()
                )
            ]

            return {
                "image_path": image_path,
                "predictions": predictions
            }
        except Exception as e:
            print(f"处理图片 {image_path} 出错: {e}")
            return None

    @staticmethod
    def _score_predictions(results: list, true_labels: dict) -> tuple:
        """Compute ``(accuracy, average_recall, correct_count)`` for ``results``.

        A prediction is correct when the lower-cased, whitespace-normalised
        true class name occurs inside the top-1 predicted class name, so
        multi-word labels can match as well. Images without a ground-truth
        label, or without any prediction, count as misses; unlabeled images
        do not create a recall class of their own.
        """
        correct_count = 0
        class_actual_count = {}
        class_correct_count = {}

        for prediction in results:
            raw_label = true_labels.get(prediction['image_path']) or ""
            true_class = " ".join(raw_label.lower().split())
            if not true_class:
                # An empty string is a substring of everything; without this
                # guard every unlabeled image would be scored as correct.
                continue

            class_actual_count[true_class] = class_actual_count.get(true_class, 0) + 1

            candidates = prediction['predictions']
            if not candidates:
                continue

            top1_prediction = max(candidates, key=lambda x: x['confidence'])
            predicted_class = " ".join(top1_prediction['class_name'].lower().split())
            if true_class in predicted_class:
                correct_count += 1
                class_correct_count[true_class] = class_correct_count.get(true_class, 0) + 1

        total_count = len(results)
        accuracy = correct_count / total_count if total_count > 0 else 0
        recall_per_class = {
            class_name: class_correct_count.get(class_name, 0) / actual
            for class_name, actual in class_actual_count.items()
        }
        average_recall = (
            sum(recall_per_class.values()) / len(recall_per_class)
            if recall_per_class else 0
        )
        return accuracy, average_recall, correct_count

    def batch_predict_and_evaluate(self, image_paths: list, true_labels: dict, top_k: int = 3) -> dict:
        """Predict every image and report accuracy, recall and throughput.

        Args:
            image_paths: Paths of the images to classify.
            true_labels: Mapping from image path to its true class name.
            top_k: Number of predictions requested per image.

        Returns:
            A dict with ``status``, a ``metrics`` dict (``accuracy`` and
            ``average_recall`` as percentages rounded to two decimals,
            ``total_images``, ``correct_predictions``,
            ``speed_images_per_second``) and up to three
            ``sample_predictions``. Images that failed to process are left
            out of all counts.
        """
        results = []
        start_time = time.perf_counter()

        for image_path in tqdm(image_paths):
            result = self.predict_image_path(image_path, top_k)
            if result:
                results.append(result)

        total_time = time.perf_counter() - start_time
        images_per_second = len(results) / total_time if total_time > 0 else 0

        accuracy, average_recall, correct_count = self._score_predictions(results, true_labels)

        return {
            "status": "success",
            "metrics": {
                "accuracy": round(accuracy * 100, 2),  # percentage
                "average_recall": round(average_recall * 100, 2),  # percentage
                "total_images": len(results),
                "correct_predictions": correct_count,
                "speed_images_per_second": round(images_per_second, 2)
            },
            "sample_predictions": results[:3]  # a few example predictions
        }
|
||
|
||
# Flask service setup.
# Both locations default to the in-container paths and can be overridden
# through environment variables of the same name.
MODEL_PATH = os.getenv("MODEL_PATH", "/model")
DATASET_PATH = os.getenv("DATASET_PATH", "/app/dataset")
app = Flask(__name__)
# The classifier is built once at import time and shared by all requests.
classifier = COCOImageClassifier(MODEL_PATH)
|
||
|
||
@app.route('/v1/private/s782b4996', methods=['POST'])
def evaluate():
    """Sample images from the dataset, classify them and return the metrics.

    Request body (optional JSON object):
        ``limit``: maximum number of images to evaluate, a non-negative
        integer (default 20).

    Up to three random images are taken from every sub-directory of
    ``DATASET_PATH``. The true class of a directory is the part of its name
    after the first ``.``, or the whole name when there is no such part.

    Returns:
        The JSON result of ``classifier.batch_predict_and_evaluate`` on
        success, a 400 JSON error for an invalid ``limit``, or a 500 JSON
        error carrying the exception message for any other failure.
    """
    # A missing or malformed JSON body falls back to the defaults instead of
    # surfacing as an internal error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    limit = data.get("limit", 20)
    # bool is a subclass of int, so it has to be rejected explicitly; a
    # negative limit would silently slice from the end of the list.
    if isinstance(limit, bool) or not isinstance(limit, int) or limit < 0:
        return jsonify({
            "status": "error",
            "message": "limit must be a non-negative integer"
        }), 400

    try:
        local_image_paths = []
        true_labels = {}
        for folder in os.listdir(DATASET_PATH):
            folder_path = os.path.join(DATASET_PATH, folder)
            if not os.path.isdir(folder_path):
                continue

            # Tolerate directory names without a '.' rather than failing the
            # whole request with an IndexError.
            _, _, suffix = folder.partition('.')
            class_name = suffix or folder

            image_files = [
                os.path.join(folder_path, f)
                for f in os.listdir(folder_path)
                if f.lower().endswith(('.jpg', '.jpeg', '.png'))
            ]
            selected_images = random.sample(image_files, min(3, len(image_files)))
            for image_path in selected_images:
                local_image_paths.append(image_path)
                true_labels[image_path] = class_name

        # Cap the number of images that are actually evaluated.
        local_image_paths = local_image_paths[:limit]

        result = classifier.batch_predict_and_evaluate(local_image_paths, true_labels, top_k=3)
        return jsonify(result)

    except Exception as e:
        return jsonify({
            "status": "error",
            "message": str(e)
        }), 500
|
||
|
||
@app.route('/health', methods=['GET'])
def health_check():
    """Return a fixed "healthy" status together with the active torch device."""
    payload = {"status": "healthy", "device": str(device)}
    return jsonify(payload), 200
|
||
|
||
if __name__ == "__main__":
    # Listen on every interface, port 8000, with the Flask debugger disabled.
    app.run(debug=False, port=8000, host='0.0.0.0')