# Image-classification HTTP service (Flask + transformers) with MetaX GPU support.
import json
import os
import shutil
import subprocess
import time
from io import BytesIO

import requests
import torch
from flask import Flask, jsonify, request
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
class ImageClassifier:
    """Wraps a Hugging Face image-classification model loaded from a local directory."""

    def __init__(self, model_path: str, device: torch.device):
        """Initialize the classifier on the given device.

        Args:
            model_path: Directory holding a Hugging Face image-classification
                model; must contain config.json and pytorch_model.bin.
            device: torch device to run inference on.

        Raises:
            ValueError: If the model path is missing, is not a directory,
                or lacks a required file.
        """
        # Validate the model path before handing it to transformers.
        if not os.path.exists(model_path):
            raise ValueError(f"模型路径不存在: {model_path}")
        if not os.path.isdir(model_path):
            raise ValueError(f"模型路径不是目录: {model_path}")

        # Check the required model files are present.
        required_files = ["config.json", "pytorch_model.bin"]
        missing_files = [f for f in required_files if not os.path.exists(os.path.join(model_path, f))]
        if missing_files:
            raise ValueError(f"模型路径缺少必要文件: {missing_files}")

        self.processor = AutoImageProcessor.from_pretrained(model_path)
        self.model = AutoModelForImageClassification.from_pretrained(model_path)

        # Move the model to the requested device.
        self.model = self.model.to(device)
        self.device = device

        # Report which accelerator (if any) the model landed on.
        if device.type == "cuda":
            if is_metax_gpu():
                print(f"模型是否在沐曦GPU上: {next(self.model.parameters()).device.type == 'cuda'}")
            else:
                print(f"模型是否在NVIDIA GPU上: {next(self.model.parameters()).is_cuda}")
        else:
            print(f"模型在 {device.type.upper()} 上运行")

        # Wrap in DataParallel when multiple CUDA devices are visible.
        if device.type == "cuda" and torch.cuda.device_count() > 1:
            self.model = torch.nn.DataParallel(self.model)

        # DataParallel hides the underlying config behind .module.
        self.id2label = self.model.module.config.id2label if hasattr(self.model, 'module') else self.model.config.id2label

    def predict_single_image(self, image: Image.Image) -> dict:
        """Classify a single PIL image and return the top-1 prediction.

        Args:
            image: Input image (any mode the processor accepts).

        Returns:
            dict with keys class_id, class_name, confidence, device_used and
            processing_time; on failure class_id is -1, class_name is "error"
            and an "error" key carries the message (no exception escapes).
        """
        try:
            # Preprocess and move inputs to the model's device.
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = inputs.to(self.device)

            # Single inference pass.
            # BUGFIX: removed leftover benchmarking code that re-ran the model
            # 1000 extra times per request with 'muxi T1'/'muxi T2' debug prints.
            start_time = time.time()
            with torch.no_grad():
                outputs = self.model(**inputs)
            processing_time = time.time() - start_time

            # Softmax over logits; keep only the highest-confidence class.
            logits = outputs.logits
            probs = torch.nn.functional.softmax(logits, dim=1)
            top_probs, top_indices = probs.topk(1, dim=1)

            class_idx = top_indices[0, 0].item()
            confidence = top_probs[0, 0].item()

            return {
                "class_id": class_idx,
                "class_name": self.id2label[class_idx],
                "confidence": confidence,
                "device_used": str(self.device),
                "processing_time": processing_time
            }

        except Exception as e:
            # Service boundary: report the failure in-band instead of raising.
            print(f"处理图片时出错: {e}")
            return {
                "class_id": -1,
                "class_name": "error",
                "confidence": 0.0,
                "device_used": str(self.device),
                "processing_time": 0.0,
                "error": str(e)
            }
def is_metax_gpu():
    """Best-effort detection of a MetaX (沐曦) GPU on this host.

    First looks for the mx-smi management tool on PATH, then falls back to
    scanning lspci output for MetaX identifiers (vendor id 0x1e66).

    Returns:
        bool: True if a MetaX GPU appears to be present, False otherwise
        (including when the probes themselves fail).
    """
    # mx-smi on PATH is the strongest signal of a MetaX driver stack.
    # BUGFIX: replaced subprocess.run(['which', ...]) with stdlib shutil.which.
    if shutil.which('mx-smi') is not None:
        return True

    try:
        # Fall back to PCI device enumeration.
        result = subprocess.run(['lspci'], capture_output=True, text=True)
        if 'MetaX' in result.stdout or 'MXC' in result.stdout or '1e66' in result.stdout:
            return True
    except (OSError, subprocess.SubprocessError):
        # BUGFIX: was a bare `except:` which also swallowed SystemExit and
        # KeyboardInterrupt; now only probe failures (lspci missing, etc.).
        pass
    return False
def check_metax_available():
    """Return True when CUDA is usable and the card is a MetaX GPU."""
    if not torch.cuda.is_available():
        return False
    return is_metax_gpu()
def get_device():
    """Select the best available compute device, preferring MetaX GPUs.

    Order of preference: MetaX GPU, then any other CUDA device, then CPU.
    Prints which choice was made.
    """
    # MetaX GPUs take priority over generic CUDA devices.
    if check_metax_available():
        print("检测到沐曦GPU可用")
        return torch.device("cuda:0")

    # Any other CUDA-capable (e.g. NVIDIA) device.
    if torch.cuda.is_available():
        print("检测到NVIDIA GPU可用")
        return torch.device("cuda:0")

    # No accelerator found — run on CPU.
    print("未检测到加速设备,使用CPU")
    return torch.device("cpu")
def setup_metax_environment():
    """Configure environment variables for the MetaX GPU, when one is usable."""
    if not check_metax_available():
        return

    print("正在设置沐曦GPU环境...")
    try:
        # Restrict visibility to the first MetaX card.
        os.environ['MX_VISIBLE_DEVICES'] = '0'
        print("沐曦GPU环境设置完成")
    except Exception as e:
        # Best-effort: report but do not abort service startup.
        print(f"设置沐曦环境时出错: {e}")
# Service initialization (runs at import time).
app = Flask(__name__)
MODEL_PATH = os.environ.get("MODEL_PATH", "/model")  # model directory (env var or default)

# Order matters: configure the MetaX environment first, then pick the device,
# then load the classifier onto it.
setup_metax_environment()
device = get_device()
classifier = ImageClassifier(MODEL_PATH, device)
def _error_payload(message: str) -> dict:
    """Build the standard error-prediction payload carrying *message*."""
    return {
        "class_id": -1,
        "class_name": "error",
        "confidence": 0.0,
        "device_used": str(device),
        "processing_time": 0.0,
        "error": message
    }


@app.route('/v1/private/s782b4996', methods=['POST'])
def predict_single():
    """Accept one uploaded image (form field 'image') and return its top-1 prediction.

    Responses:
        200 — {"prediction": {...}, "status": "success"}
        400 — missing 'image' field
        500 — decode or inference failure
    """
    # Reject requests that carry no image upload.
    if 'image' not in request.files:
        return jsonify({
            "prediction": _error_payload("请求中未包含图片"),
            "status": "error"
        }), 400

    image_file = request.files['image']
    try:
        # Decode the upload and normalize to RGB for the processor.
        image = Image.open(BytesIO(image_file.read())).convert("RGB")

        prediction_result = classifier.predict_single_image(image)

        return jsonify({
            "prediction": prediction_result,
            "status": "success"
        })

    except Exception as e:
        return jsonify({
            "prediction": _error_payload(str(e)),
            "status": "error"
        }), 500
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness endpoint reporting accelerator availability and device in use."""
    payload = {
        "status": "healthy",
        "metax_available": check_metax_available(),
        "device_used": str(device),
        "cpu_threads": torch.get_num_threads()
    }
    return jsonify(payload), 200
if __name__ == "__main__":
    # Serve on all interfaces; port 80 typically requires elevated privileges.
    app.run(host='0.0.0.0', port=80, debug=False)