"""Image-classification HTTP service.

Loads a Hugging Face image-classification model twice — one copy on GPU
(when CUDA is available) and one on CPU capped at 4 threads — and serves
predictions from both via a small Flask API so the two devices' latencies
can be compared per request.
"""

import os

# Thread-limit environment variables must be set BEFORE torch / BLAS
# libraries are imported — they are read at native-library init time, so
# setting them after `import torch` may silently have no effect.
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["MKL_NUM_THREADS"] = "4"
os.environ["NUMEXPR_NUM_THREADS"] = "4"
os.environ["OPENBLAS_NUM_THREADS"] = "4"
os.environ["VECLIB_MAXIMUM_THREADS"] = "4"

import multiprocessing
import time
from io import BytesIO

import torch
from flask import Flask, jsonify, request
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

torch.set_num_threads(4)  # cap PyTorch intra-op CPU threads

# Device configuration: prefer CUDA when present, always keep a CPU device.
device_cuda = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_cpu = torch.device("cpu")
print(f"当前CUDA设备: {device_cuda}, CPU设备: {device_cpu}")
print(f"CPU核心数设置: {torch.get_num_threads()}")


def _error_payload(device: torch.device, message: str) -> dict:
    """Build the uniform error-shaped prediction dict used across the API."""
    return {
        "class_id": -1,
        "class_name": "error",
        "confidence": 0.0,
        "device_used": str(device),
        "processing_time": 0.0,
        "error": message,
    }


class ImageClassifier:
    """One image processor plus up to two model instances (GPU and CPU)."""

    def __init__(self, model_path: str):
        self.processor = AutoImageProcessor.from_pretrained(model_path)
        # Load a dedicated model instance per device so the two predictions
        # cannot interfere with each other.
        if device_cuda.type == "cuda":
            self.model_cuda = AutoModelForImageClassification.from_pretrained(
                model_path
            ).to(device_cuda)
        else:
            self.model_cuda = None  # no CUDA: skip loading the GPU copy
        self.model_cpu = AutoModelForImageClassification.from_pretrained(
            model_path
        ).to(device_cpu)
        # id -> human-readable label mapping, taken from the model config.
        self.id2label = self.model_cpu.config.id2label

    def _predict_with_model(self, image, model, device) -> dict:
        """Run one prediction on `model`/`device`, timing just this call.

        Returns a dict with class id/name, confidence, device string and
        elapsed seconds (6 decimal places); on failure returns the same
        shape with an "error" entry instead of raising.
        """
        try:
            start_time = time.perf_counter()  # high-resolution timer
            # Preprocess the image and move the tensors to the target device.
            inputs = self.processor(images=image, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model(**inputs)
            logits = outputs.logits
            probs = torch.nn.functional.softmax(logits, dim=1)
            max_prob, max_idx = probs.max(dim=1)
            class_idx = max_idx.item()
            processing_time = round(time.perf_counter() - start_time, 6)
            return {
                "class_id": class_idx,
                "class_name": self.id2label[class_idx],
                "confidence": float(max_prob.item()),
                "device_used": str(device),
                "processing_time": processing_time,
            }
        except Exception as e:  # surface the failure inside the payload
            return _error_payload(device, str(e))

    def predict_single_image(self, image) -> dict:
        """Predict one image on both the GPU model (if any) and the CPU model."""
        results = {"status": "success"}
        if self.model_cuda is not None:
            cuda_result = self._predict_with_model(
                image, self.model_cuda, device_cuda
            )
        else:
            cuda_result = _error_payload(
                device_cuda, "CUDA设备不可用,未加载CUDA模型"
            )
        results["cuda_prediction"] = cuda_result
        # CPU prediction (thread count already capped at 4 above).
        results["cpu_prediction"] = self._predict_with_model(
            image, self.model_cpu, device_cpu
        )
        return results


# Service initialisation (runs at import time).
app = Flask(__name__)
MODEL_PATH = os.environ.get("MODEL_PATH", "/model")  # model dir, env override
classifier = ImageClassifier(MODEL_PATH)


@app.route('/v1/private/s782b4996', methods=['POST'])
def predict_single():
    """Accept one uploaded image and return both devices' predictions."""
    if 'image' not in request.files:
        return jsonify({
            "status": "error",
            "cuda_prediction": _error_payload(device_cuda, "请求中未包含图片"),
            "cpu_prediction": _error_payload(device_cpu, "请求中未包含图片"),
        }), 400
    image_file = request.files['image']
    try:
        image = Image.open(BytesIO(image_file.read())).convert("RGB")
        result = classifier.predict_single_image(image)
        return jsonify(result)
    except Exception as e:
        return jsonify({
            "status": "error",
            "cuda_prediction": _error_payload(device_cuda, str(e)),
            "cpu_prediction": _error_payload(device_cpu, str(e)),
        }), 500


@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: report device availability and CPU thread count."""
    return jsonify({
        "status": "healthy",
        "cuda_available": device_cuda.type == "cuda",
        "cuda_device": str(device_cuda),
        "cpu_device": str(device_cpu),
        "cpu_threads": torch.get_num_threads(),
    }), 200


if __name__ == "__main__":
    app.run(host='0.0.0.0', port=80, debug=False)