Files
enginex-bi_series-vc-cnn/model_test_caltech_http.py
zhousha 55a67e817e update
2025-08-06 15:38:55 +08:00

167 lines
5.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Runtime setup.
# NOTE: thread-limit environment variables must be exported BEFORE torch /
# numpy are imported — their OpenMP/MKL/BLAS pools are initialized at import
# time, after which these variables are silently ignored. The original code
# set them after `import torch`, which had no effect.
import os

# Cap CPU parallelism at 4 cores for every math backend.
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["MKL_NUM_THREADS"] = "4"
os.environ["NUMEXPR_NUM_THREADS"] = "4"
os.environ["OPENBLAS_NUM_THREADS"] = "4"
os.environ["VECLIB_MAXIMUM_THREADS"] = "4"

import multiprocessing
import time
from io import BytesIO

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
from flask import Flask, request, jsonify

torch.set_num_threads(4)  # also cap PyTorch's own intra-op thread pool

# Device configuration: prefer CUDA when present; always keep a CPU device.
device_cuda = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_cpu = torch.device("cpu")
print(f"当前CUDA设备: {device_cuda}, CPU设备: {device_cpu}")
print(f"CPU核心数设置: {torch.get_num_threads()}")
class ImageClassifier:
    """Wraps a HuggingFace image-classification checkpoint and runs the same
    prediction on both a CUDA model (when available) and a CPU model, so the
    two devices can be timed side by side."""

    def __init__(self, model_path: str):
        """Load the image processor and one model instance per device.

        Args:
            model_path: Local path (or hub id) of the pretrained checkpoint.
        """
        self.processor = AutoImageProcessor.from_pretrained(model_path)
        # Only instantiate the GPU copy when CUDA is actually present.
        if device_cuda.type == "cuda":
            self.model_cuda = AutoModelForImageClassification.from_pretrained(model_path).to(device_cuda)
        else:
            self.model_cuda = None
        self.model_cpu = AutoModelForImageClassification.from_pretrained(model_path).to(device_cpu)
        # Cache the id -> human-readable label mapping from the model config.
        self.id2label = self.model_cpu.config.id2label

    @staticmethod
    def _error_result(device, message: str) -> dict:
        """Build the standard error payload returned whenever a prediction
        cannot be produced for *device* (deduplicates four copies of this
        dict literal in the original code)."""
        return {
            "class_id": -1,
            "class_name": "error",
            "confidence": 0.0,
            "device_used": str(device),
            "processing_time": 0.0,
            "error": message,
        }

    def _predict_with_model(self, image, model, device) -> dict:
        """Run one forward pass of *model* on *device* with per-call timing.

        Returns a dict with class id/name, confidence, the device used and
        the wall-clock processing time in seconds (rounded to 6 decimals).
        On any failure an error payload is returned instead of raising.
        """
        try:
            start_time = time.perf_counter()  # high-resolution timer
            # Preprocess the image and move tensors to the target device.
            inputs = self.processor(images=image, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model(**inputs)
                logits = outputs.logits
                probs = torch.nn.functional.softmax(logits, dim=1)
                max_prob, max_idx = probs.max(dim=1)
                class_idx = max_idx.item()
            processing_time = round(time.perf_counter() - start_time, 6)
            return {
                "class_id": class_idx,
                "class_name": self.id2label[class_idx],
                "confidence": float(max_prob.item()),
                "device_used": str(device),
                "processing_time": processing_time,
            }
        except Exception as e:
            return self._error_result(device, str(e))

    def predict_single_image(self, image) -> dict:
        """Predict a single PIL image with both the GPU and CPU models.

        Returns a dict with "status", "cuda_prediction" and "cpu_prediction"
        keys; when CUDA is unavailable the cuda entry is an error payload.
        """
        results = {"status": "success"}
        if self.model_cuda is not None:
            results["cuda_prediction"] = self._predict_with_model(image, self.model_cuda, device_cuda)
        else:
            results["cuda_prediction"] = self._error_result(
                device_cuda, "CUDA设备不可用未加载CUDA模型")
        # CPU prediction (thread count capped at module import time).
        results["cpu_prediction"] = self._predict_with_model(image, self.model_cpu, device_cpu)
        return results
# Service bootstrap: resolve the checkpoint location, create the Flask
# application object, and load the classifier once at startup.
MODEL_PATH = os.environ.get("MODEL_PATH", "/model")  # env override or default path
app = Flask(__name__)
classifier = ImageClassifier(MODEL_PATH)
@app.route('/v1/private/s782b4996', methods=['POST'])
def predict_single():
    """Accept one uploaded image (form field 'image') and return the GPU and
    CPU predictions plus their processing times as JSON.

    Responses:
        200 -- normal prediction result from the classifier.
        400 -- the request did not contain an 'image' file.
        500 -- the image could not be decoded or prediction raised.
    """

    def _error_body(message: str) -> dict:
        # Mirror the classifier's result schema so clients can parse success
        # and failure responses identically (deduplicates two 16-line
        # copy-pasted response literals in the original).
        entry = {
            "class_id": -1,
            "class_name": "error",
            "confidence": 0.0,
            "processing_time": 0.0,
            "error": message,
        }
        return {
            "status": "error",
            "cuda_prediction": dict(entry, device_used=str(device_cuda)),
            "cpu_prediction": dict(entry, device_used=str(device_cpu)),
        }

    if 'image' not in request.files:
        return jsonify(_error_body("请求中未包含图片")), 400
    image_file = request.files['image']
    try:
        # Decode to RGB so the processor always sees 3-channel input.
        image = Image.open(BytesIO(image_file.read())).convert("RGB")
        result = classifier.predict_single_image(image)
        return jsonify(result)
    except Exception as e:
        return jsonify(_error_body(str(e))), 500
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: report device availability and CPU thread settings."""
    payload = {
        "status": "healthy",
        "cuda_available": device_cuda.type == "cuda",
        "cuda_device": str(device_cuda),
        "cpu_device": str(device_cpu),
        "cpu_threads": torch.get_num_threads(),  # configured CPU thread count
    }
    return jsonify(payload), 200
if __name__ == "__main__":
    # Serve on all interfaces, port 80; debug disabled for production use.
    # NOTE(review): Flask's built-in server is single-process dev tooling —
    # presumably a WSGI server (gunicorn/uwsgi) fronts this in production; confirm.
    app.run(host='0.0.0.0', port=80, debug=False)