89 lines
3.4 KiB
Python
89 lines
3.4 KiB
Python
import torch
|
|
from PIL import Image
|
|
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
|
import os
|
|
from flask import Flask, request, jsonify
|
|
from io import BytesIO
|
|
|
|
# Device configuration: use CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"当前使用的设备: {device}")
|
|
|
|
class ImageClassifier:
    """Image-classification model wrapped for single-image inference.

    The processor and model are loaded from the directory chosen by
    ``_resolve_model_dir``. The model is moved to the module-level ``device``.
    When more than one CUDA device is visible, it is also wrapped in
    ``torch.nn.DataParallel``.
    """

    def __init__(self, model_path: str):
        """Load the image processor and classification model.

        Args:
            model_path: Root directory containing the model. The model is
                normally in a sub-directory, or directly in ``model_path``.

        Raises:
            ValueError: If no usable model directory is found under ``model_path``.
        """
        actual_model_path = self._resolve_model_dir(model_path)
        print(f"加载模型从: {actual_model_path}")

        self.processor = AutoImageProcessor.from_pretrained(actual_model_path)
        model = AutoModelForImageClassification.from_pretrained(actual_model_path)
        # Grab the label mapping before a possible DataParallel wrap hides
        # ``config`` behind ``.module``.
        self.id2label = model.config.id2label
        self.model = model.to(device)

        if torch.cuda.device_count() > 1:
            print(f"使用 {torch.cuda.device_count()} 块GPU")
            self.model = torch.nn.DataParallel(self.model)

    @staticmethod
    def _resolve_model_dir(model_path: str) -> str:
        """Return the directory under ``model_path`` that holds the model files.

        Sub-directories whose names do not start with ``.`` are considered in
        sorted order, and the first one is used. Sorting makes the choice
        deterministic instead of depending on the arbitrary order of
        ``os.listdir``. If there is no such sub-directory but ``model_path``
        itself contains ``config.json``, ``model_path`` is returned.

        Raises:
            ValueError: If neither a sub-directory nor ``config.json`` is found.
        """
        subdirs = sorted(
            entry
            for entry in os.listdir(model_path)
            if not entry.startswith(".")
            and os.path.isdir(os.path.join(model_path, entry))
        )
        if subdirs:
            return os.path.join(model_path, subdirs[0])
        if os.path.isfile(os.path.join(model_path, "config.json")):
            return model_path
        raise ValueError(f"在 {model_path} 下未找到任何子目录,无法加载模型")

    def predict_single_image(self, image) -> dict:
        """Classify one image and return its highest-confidence prediction.

        Args:
            image: One image accepted by the loaded processor. The Flask route
                passes an RGB ``PIL.Image.Image``.

        Returns:
            On success, ``{"status": "success", "top_prediction": {"class_id",
            "class_name", "confidence"}}``. If any step raises, returns
            ``{"status": "error", "message": str(exc)}``. Exceptions are never
            propagated to the caller.
        """
        try:
            inputs = self.processor(images=image, return_tensors="pt").to(device)

            with torch.no_grad():
                outputs = self.model(**inputs)

            # Softmax over dim=1 turns logits into per-class probabilities.
            probs = torch.nn.functional.softmax(outputs.logits, dim=1)
            max_prob, max_idx = probs.max(dim=1)
            # .item() requires a single element, i.e. exactly one input image.
            class_idx = max_idx.item()

            return {
                "status": "success",
                "top_prediction": {
                    "class_id": class_idx,
                    "class_name": self.id2label[class_idx],
                    "confidence": max_prob.item(),
                },
            }
        except Exception as e:
            # Deliberate best-effort contract: report the failure in-band.
            return {"status": "error", "message": str(e)}
|
|
|
|
# Service initialization: build the Flask app and eagerly load one shared
# classifier at import time, so every request reuses the same model.
app = Flask(__name__)
# Model root directory; the MODEL_PATH environment variable overrides the
# "/model" default.
MODEL_PATH = os.getenv("MODEL_PATH", "/model")
classifier = ImageClassifier(MODEL_PATH)
|
|
|
|
@app.route('/v1/private/s782b4996', methods=['POST'])
def predict_single():
    """Classify one uploaded image and return its top prediction as JSON.

    Expects a multipart/form-data POST with the image in the ``image`` file field.

    Responses:
        200 -- prediction succeeded; body is the classifier's success payload.
        400 -- ``image`` field missing, or its bytes cannot be decoded as an image.
        500 -- the classifier reported an error, or an unexpected exception occurred.
    """
    if 'image' not in request.files:
        return jsonify({"status": "error", "message": "请求中未包含图片"}), 400

    image_file = request.files['image']

    # Undecodable, truncated or oversized uploads are client errors.
    # UnidentifiedImageError is an OSError subclass.
    try:
        with Image.open(BytesIO(image_file.read())) as src:
            image = src.convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        return jsonify({"status": "error", "message": str(e)}), 400

    try:
        result = classifier.predict_single_image(image)
    except Exception as e:
        app.logger.exception("Prediction failed")
        return jsonify({"status": "error", "message": str(e)}), 500

    # predict_single_image reports failures in-band. Surface them as HTTP 500
    # instead of a misleading 200.
    if result.get("status") != "success":
        app.logger.error("Prediction failed: %s", result.get("message"))
        return jsonify(result), 500
    return jsonify(result)
|
|
|
|
@app.route('/health', methods=['GET'])
def health_check():
    """Health-check endpoint reporting service status and the compute device in use."""
    body = {"status": "healthy", "device": str(device)}
    return jsonify(body), 200
|
|
|
|
if __name__ == "__main__":
    # Start Flask's built-in server on all interfaces, port 8000, debug off.
    # NOTE(review): this is the Werkzeug development server; a production
    # WSGI server may be preferable for deployment.
    server_options = {"host": '0.0.0.0', "port": 8000, "debug": False}
    app.run(**server_options)