Files
enginex-c_series-vc-cnn/model_test_caltech_http_muxi.py

214 lines
7.0 KiB
Python
Raw Permalink Normal View History

2025-09-18 15:32:10 +08:00
import functools
import json
import os
import shutil
import subprocess
import time
from io import BytesIO

import requests
import torch
from flask import Flask, jsonify, request
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
class ImageClassifier:
    """Wrap a local Hugging Face image-classification model bound to one torch device.

    Loads an ``AutoImageProcessor`` / ``AutoModelForImageClassification`` pair
    from a directory and moves the model to ``device``. When more than one CUDA
    device is visible, the model is wrapped in ``DataParallel``. Single PIL
    images are classified, returning only the top-1 label.
    """

    # config.json plus at least one of these weight files must be present.
    # These are the file names transformers' save_pretrained writes for
    # single-file and sharded checkpoints, in PyTorch and safetensors formats.
    _WEIGHT_FILES = (
        "pytorch_model.bin",
        "model.safetensors",
        "pytorch_model.bin.index.json",
        "model.safetensors.index.json",
    )

    def __init__(self, model_path: str, device: torch.device, *, benchmark_iters: int = 1000):
        """Load the processor and model from ``model_path`` onto ``device``.

        Args:
            model_path: local directory containing a saved transformers model.
            device: torch device the model (and every input) is moved to.
            benchmark_iters: number of extra timed forward passes run on every
                prediction (reported as ``muxi T2``). Defaults to 1000, which
                keeps the original benchmarking behaviour; 0 disables the loop.

        Raises:
            ValueError: if ``benchmark_iters`` is negative, or ``model_path``
                does not exist, is not a directory, or lacks ``config.json``
                or every recognised weight file.
        """
        if benchmark_iters < 0:
            raise ValueError(f"benchmark_iters 必须为非负整数: {benchmark_iters}")
        self._validate_model_dir(model_path)
        self.benchmark_iters = benchmark_iters
        self.processor = AutoImageProcessor.from_pretrained(model_path)
        self.model = AutoModelForImageClassification.from_pretrained(model_path).to(device)
        self.device = device

        # Report where the parameters actually live.
        if device.type == "cuda":
            first_param = next(self.model.parameters())
            if is_metax_gpu():
                print(f"模型是否在沐曦GPU上: {first_param.device.type == 'cuda'}")
            else:
                print(f"模型是否在NVIDIA GPU上: {first_param.is_cuda}")
        else:
            print(f"模型在 {device.type.upper()} 上运行")

        # Replicate across all visible GPUs. DataParallel exposes the wrapped
        # model as ``.module``, which is where the config lives afterwards.
        if device.type == "cuda" and torch.cuda.device_count() > 1:
            self.model = torch.nn.DataParallel(self.model)
        self.id2label = getattr(self.model, "module", self.model).config.id2label

    @classmethod
    def _validate_model_dir(cls, model_path: str) -> None:
        """Raise ValueError unless ``model_path`` is a directory with a config and weights."""
        if not os.path.exists(model_path):
            raise ValueError(f"模型路径不存在: {model_path}")
        if not os.path.isdir(model_path):
            raise ValueError(f"模型路径不是目录: {model_path}")

        def present(name: str) -> bool:
            return os.path.exists(os.path.join(model_path, name))

        missing_files = [] if present("config.json") else ["config.json"]
        if not any(present(name) for name in cls._WEIGHT_FILES):
            # Any one weight format is enough; list the accepted alternatives.
            missing_files.append(" | ".join(cls._WEIGHT_FILES))
        if missing_files:
            raise ValueError(f"模型路径缺少必要文件: {missing_files}")

    def _sync(self) -> None:
        """Wait for queued CUDA work so wall-clock timings cover finished GPU work."""
        if self.device.type == "cuda":
            torch.cuda.synchronize(self.device)

    def predict_single_image(self, image: "Image.Image") -> dict:
        """Classify one PIL image and return the top-1 prediction.

        Besides the prediction pass, this also runs ``benchmark_iters`` extra
        forward passes on the same input. It prints the first-pass time
        (``muxi T1``) and the loop time (``muxi T2``).

        Returns:
            dict with ``class_id``, ``class_name``, ``confidence``,
            ``device_used`` and ``processing_time``. ``processing_time`` is the
            seconds spent in the timed inference section, benchmark loop
            included. Errors are not raised: on any exception the dict has
            ``class_id == -1``, ``class_name == "error"``, zero confidence and
            time, plus an ``error`` message.
        """
        try:
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)
            start_time = time.perf_counter()
            with torch.no_grad():
                ts = time.perf_counter()
                outputs = self.model(**inputs)
                # CUDA launches are asynchronous: synchronize before reading the clock.
                self._sync()
                print('muxi T1', time.perf_counter() - ts, flush=True)
                if self.benchmark_iters:
                    ts = time.perf_counter()
                    for _ in range(self.benchmark_iters):
                        outputs = self.model(**inputs)
                    self._sync()
                    print('muxi T2', time.perf_counter() - ts, flush=True)
            processing_time = time.perf_counter() - start_time

            # Keep only the highest-probability class.
            probs = torch.nn.functional.softmax(outputs.logits, dim=1)
            top_probs, top_indices = probs.topk(1, dim=1)
            class_idx = top_indices[0, 0].item()
            confidence = top_probs[0, 0].item()
            return {
                "class_id": class_idx,
                "class_name": self.id2label[class_idx],
                "confidence": confidence,
                "device_used": str(self.device),
                "processing_time": processing_time,
            }
        except Exception as e:
            print(f"处理图片时出错: {e}")
            return {
                "class_id": -1,
                "class_name": "error",
                "confidence": 0.0,
                "device_used": str(self.device),
                "processing_time": 0.0,
                "error": str(e),
            }
@functools.lru_cache(maxsize=None)
def is_metax_gpu():
    """Return True if this host appears to have a MetaX (沐曦) GPU.

    Two best-effort probes are tried in order:

    1. the ``mx-smi`` command is found on PATH;
    2. ``lspci`` output mentions "MetaX", "MXC" or "1e66".

    A probe that cannot run (missing binary, OS error, timeout) counts as
    "not detected", and the next probe is still tried. The result is cached
    because the hardware does not change while the process runs and callers
    (e.g. the health check) may ask repeatedly.
    """
    if shutil.which('mx-smi') is not None:
        return True
    try:
        result = subprocess.run(['lspci'], capture_output=True, text=True, timeout=10)
    except (OSError, subprocess.SubprocessError):
        return False
    return any(marker in result.stdout for marker in ('MetaX', 'MXC', '1e66'))
def check_metax_available():
    """Report whether a MetaX GPU can actually be used.

    Requires both that PyTorch reports a CUDA-compatible device and that
    ``is_metax_gpu()`` detects MetaX hardware. The hardware probe is skipped
    when PyTorch sees no device.
    """
    if not torch.cuda.is_available():
        return False
    return is_metax_gpu()
def get_device():
    """Choose the torch device to run on: MetaX GPU first, then NVIDIA GPU, then CPU.

    Prints which option was selected. Both GPU cases map to ``cuda:0``; the
    MetaX check itself requires ``torch.cuda.is_available()``.
    """
    if check_metax_available():
        banner, device_name = "检测到沐曦GPU可用", "cuda:0"
    elif torch.cuda.is_available():
        banner, device_name = "检测到NVIDIA GPU可用", "cuda:0"
    else:
        banner, device_name = "未检测到加速设备使用CPU", "cpu"
    print(banner)
    return torch.device(device_name)
def setup_metax_environment():
    """Restrict the process to the first MetaX card when a MetaX GPU is usable.

    Sets ``MX_VISIBLE_DEVICES=0`` in this process's environment. This is a
    no-op when no MetaX GPU is available.

    NOTE(review): check_metax_available() has already called
    torch.cuda.is_available() before the variable is set. If the MetaX runtime
    reads MX_VISIBLE_DEVICES only at initialisation, this may be too late to
    take effect — confirm.
    """
    if not check_metax_available():
        return
    print("正在设置沐曦GPU环境...")
    try:
        os.environ['MX_VISIBLE_DEVICES'] = '0'  # first MetaX card only
        print("沐曦GPU环境设置完成")
    except Exception as exc:
        print(f"设置沐曦环境时出错: {exc}")
# Service bootstrap. Order matters: the MetaX environment is prepared before a
# device is chosen, and the device is chosen before the model is loaded onto it.
app = Flask(__name__)
MODEL_PATH = os.getenv("MODEL_PATH", "/model")  # model directory; override via the MODEL_PATH env var
setup_metax_environment()
device = get_device()
classifier = ImageClassifier(model_path=MODEL_PATH, device=device)
def _error_response(message, status_code):
    """Build the JSON error envelope returned by the prediction endpoint.

    Args:
        message: error text placed in ``prediction.error``.
        status_code: HTTP status code to send.

    Returns:
        A ``(response, status_code)`` tuple suitable for returning from a Flask view.
    """
    return jsonify({
        "prediction": {
            "class_id": -1,
            "class_name": "error",
            "confidence": 0.0,
            "device_used": str(device),
            "processing_time": 0.0,
            "error": message,
        },
        "status": "error",
    }), status_code


@app.route('/v1/private/s782b4996', methods=['POST'])
def predict_single():
    """Classify the image uploaded as multipart field ``image``.

    Responses:
        200 with ``status: "success"`` and the prediction.
        400 when the ``image`` field is missing or cannot be decoded.
        500 when inference fails (including failures that the classifier
        reports through an ``error`` key instead of raising) or on any
        other unexpected error.
    """
    if 'image' not in request.files:
        return _error_response("请求中未包含图片", 400)
    image_file = request.files['image']

    try:
        image = Image.open(BytesIO(image_file.read())).convert("RGB")
    except OSError as e:
        # Pillow signals undecodable or truncated data with OSError
        # (UnidentifiedImageError subclasses it). That is a client error.
        return _error_response(str(e), 400)
    except Exception as e:
        return _error_response(str(e), 500)

    try:
        prediction_result = classifier.predict_single_image(image)
    except Exception as e:
        return _error_response(str(e), 500)

    # predict_single_image catches its own exceptions and returns a dict with
    # an "error" key; such a result must not be reported as success.
    if "error" in prediction_result:
        return jsonify({"prediction": prediction_result, "status": "error"}), 500
    return jsonify({"prediction": prediction_result, "status": "success"})
@app.route('/health', methods=['GET'])
def health_check():
    """Health-check endpoint.

    Always answers 200 with the service status, whether a MetaX GPU is
    usable, the device the model runs on and torch's intra-op CPU thread count.
    """
    payload = {
        "status": "healthy",
        "metax_available": check_metax_available(),
        "device_used": str(device),
        "cpu_threads": torch.get_num_threads(),
    }
    return jsonify(payload), 200
if __name__ == "__main__":
    # Listen on all interfaces. The port defaults to 80 and can be overridden
    # with the PORT environment variable, the same way MODEL_PATH is configured.
    app.run(host='0.0.0.0', port=int(os.environ.get("PORT", "80")), debug=False)