"""HTTP service that classifies a single uploaded image with a Hugging Face model.

The model directory is taken from the ``MODEL_PATH`` environment variable
(default ``/model``) and is loaded once, at import time, so every request
reuses the same processor and model instance.
"""

import os
from io import BytesIO

import torch
from flask import Flask, jsonify, request
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Device configuration: use CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"当前使用的设备: {device}")


class ImageClassifier:
    """Wrap an image processor and classification model loaded from one path."""

    def __init__(self, model_path: str):
        """Load the processor and model from *model_path* and move the model to ``device``.

        When more than one CUDA device is visible the model is wrapped in
        ``torch.nn.DataParallel``.
        """
        self.processor = AutoImageProcessor.from_pretrained(model_path)
        self.model = AutoModelForImageClassification.from_pretrained(model_path)
        self.model = self.model.to(device)

        if torch.cuda.device_count() > 1:
            print(f"使用 {torch.cuda.device_count()} 块GPU")
            self.model = torch.nn.DataParallel(self.model)

        # DataParallel exposes the wrapped model as `.module`; the label map
        # lives on the underlying model's config in either case.
        inner_model = self.model.module if hasattr(self.model, 'module') else self.model
        self.id2label = inner_model.config.id2label

    def predict_single_image(self, image) -> dict:
        """Classify one image and return the highest-confidence prediction.

        Returns a dict with ``"status": "success"`` and a ``"top_prediction"``
        entry (``class_id``, ``class_name``, ``confidence``) on success, or
        ``"status": "error"`` and a ``"message"`` if anything raises. This
        method never propagates an exception; callers must inspect ``status``.
        """
        try:
            # Preprocess the image into model-ready tensors on the target device.
            inputs = self.processor(images=image, return_tensors="pt").to(device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                logits = outputs.logits
                probs = torch.nn.functional.softmax(logits, dim=1)

            # Keep only the most probable class.
            max_prob, max_idx = probs.max(dim=1)
            class_idx = max_idx.item()

            return {
                "status": "success",
                "top_prediction": {
                    "class_id": class_idx,
                    "class_name": self.id2label[class_idx],
                    "confidence": max_prob.item()
                }
            }
        except Exception as e:
            return {
                "status": "error",
                "message": str(e)
            }


# Service initialisation.
app = Flask(__name__)
# Model path: environment variable, or the default path.
MODEL_PATH = os.environ.get("MODEL_PATH", "/model")
classifier = ImageClassifier(MODEL_PATH)


@app.route('/v1/private/s782b4996', methods=['POST'])
def predict_single():
    """Accept one uploaded image (form field ``image``) and return its top prediction.

    Responses:
        200 -- prediction succeeded; body is the classifier's success dict.
        400 -- no ``image`` part in the request, or the upload could not be
               decoded as an image.
        500 -- any other failure, including an error reported by the classifier.
    """
    # Reject requests that carry no uploaded image.
    if 'image' not in request.files:
        return jsonify({"status": "error", "message": "请求中未包含图片"}), 400

    image_file = request.files['image']

    try:
        # Read and decode the upload, normalising to RGB.
        image = Image.open(BytesIO(image_file.read())).convert("RGB")
    except OSError as e:
        # PIL signals unidentifiable or truncated image data with OSError
        # (or a subclass); that is bad client input, not a server fault.
        return jsonify({"status": "error", "message": str(e)}), 400
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500

    try:
        result = classifier.predict_single_image(image)
        # The classifier reports failures in-band; surface them as HTTP 500
        # instead of a misleading 200.
        status_code = 200 if result.get("status") == "success" else 500
        return jsonify(result), status_code
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500


@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: report a healthy status and the device in use."""
    return jsonify({"status": "healthy", "device": str(device)}), 200


if __name__ == "__main__":
    app.run(host='0.0.0.0', port=80, debug=False)