Files
enginex-mlu370-vc/model_test_caltech_http_mlu370.py
zhousha b7dce67c2c update
2025-09-15 16:08:27 +08:00

378 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import time
import os
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
from flask import Flask, request, jsonify
from io import BytesIO
# Import the Cambricon MLU extension for PyTorch. torch_mlu provides the
# ``.mlu()`` tensor/module methods used throughout this file; when it is
# missing (or fails to initialise) the service runs on CPU only.
try:
    import torch_mlu
    print(f"成功导入torch_mlu版本: {getattr(torch_mlu, '__version__', 'unknown')}")

    def check_mlu_available():
        """Return True if a small tensor can be allocated on the default MLU."""
        try:
            torch.randn(2, 2).mlu()
        except Exception:
            return False
        return True

    def get_mlu_device_count(max_devices_to_check=8):
        """Count usable MLU devices by probing indices 0, 1, 2, ... in order.

        Each probe allocates a small tensor on that device. Probing stops at
        the first index that fails, so the result is the number of consecutive
        working devices starting at index 0, capped at *max_devices_to_check*.
        A line is printed for every device that responds.
        """
        available_devices = 0
        for i in range(max_devices_to_check):
            try:
                torch.randn(2, 2).mlu(i)
            except Exception:
                break
            available_devices += 1
            print(f"MLU设备 {i} 可用")
        return available_devices

    def get_device_name(device_index):
        """Return a synthetic display name for *device_index*.

        No device query is performed; the name is derived from the index only.
        """
        return f"MLU-Device-{device_index}"

    class MLUModel:
        """Minimal stand-in for a device-query module, exposed below as ``ct``."""

        @staticmethod
        def is_mlu_available():
            return check_mlu_available()

        @staticmethod
        def device_count():
            return get_mlu_device_count()

        @staticmethod
        def get_device_name(device_index):
            # Resolves to the module-level helper above, not to this method.
            return get_device_name(device_index)

    ct = MLUModel()
    MLU_AVAILABLE = check_mlu_available()
    print(f"MLU设备可用: {MLU_AVAILABLE}")
    print(f"检测到 {ct.device_count()} 个MLU设备")
except ImportError:
    torch_mlu = None
    ct = None
    MLU_AVAILABLE = False
    print("警告: 未找到torch_mlu模块无法使用MLU设备")
except Exception as e:
    # torch_mlu imported but probing failed: degrade to CPU-only mode.
    torch_mlu = None
    ct = None
    MLU_AVAILABLE = False
    print(f"MLU初始化警告: {str(e)}")
# Limit CPU-side math libraries and PyTorch's intra-op pool to 4 threads.
# NOTE(review): these variables are set after ``import torch``, so libraries
# that were already initialised may ignore them (they still reach child
# processes); torch.set_num_threads is what actually caps PyTorch itself.
os.environ.update(
    dict.fromkeys(
        (
            "OMP_NUM_THREADS",
            "MKL_NUM_THREADS",
            "NUMEXPR_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
            "VECLIB_MAXIMUM_THREADS",
        ),
        "4",
    )
)
torch.set_num_threads(4)
class MLUImageClassifier:
    """Hugging Face image classifier served on a Cambricon MLU, with CPU fallback.

    Every prediction also times the model: one first forward pass ("T1") and
    then a loop of further passes ("T2") whose length is set by the
    ``*_benchmark_iterations`` class attributes. Timings are printed to stdout.
    """

    # Extra timed forward passes per request ("T2" in the log). The defaults
    # reproduce the original harness: 800 passes on MLU, a single pass on CPU.
    # Override on the class or on an instance to change the benchmark length.
    mlu_benchmark_iterations = 800
    cpu_benchmark_iterations = 1

    def __init__(self, model_path: str):
        """Load the image processor and model from *model_path*.

        Args:
            model_path: Anything accepted by ``from_pretrained`` (local dir or hub id).

        Raises:
            RuntimeError: if the model weights cannot be loaded. Errors from
                loading the image processor propagate unchanged.
        """
        # Probe the MLU first so the model can be placed on the right device.
        self.use_mlu = self._check_mlu_availability()
        print(f"使用设备: {'MLU' if self.use_mlu else 'CPU'}")
        self.processor = AutoImageProcessor.from_pretrained(model_path)
        # _load_model clears self.use_mlu if moving the model to the MLU fails.
        self.model = self._load_model(model_path)
        self.id2label = self.model.config.id2label
        self._verify_model_device()

    def _check_mlu_availability(self) -> bool:
        """Return True if torch_mlu is present and a tiny add runs on the MLU."""
        if torch_mlu is None:
            print("MLU不可用: torch_mlu模块未找到")
            return False
        try:
            # Allocate and add on the MLU so a real device op is exercised.
            probe = torch.randn(2, 2).mlu()
            _ = probe + probe
        except Exception as e:
            print(f"MLU设备测试失败: {e}")
            return False
        print("MLU设备可用性测试通过")
        return True

    def _load_model(self, model_path: str) -> "AutoModelForImageClassification":
        """Load the model in float32 on CPU, place it on its device, set eval mode.

        Raises:
            RuntimeError: if ``from_pretrained`` fails (nothing to fall back to).
        """
        try:
            model = AutoModelForImageClassification.from_pretrained(
                model_path,
                torch_dtype=torch.float32,
            )
        except Exception as e:
            print(f"模型加载失败: {str(e)}")
            raise RuntimeError(f"模型加载完全失败: {str(e)}") from e
        return self._move_to_device(model).eval()

    def _move_to_device(self, model):
        """Move *model* to the MLU when ``self.use_mlu`` is set, else to CPU.

        If the MLU transfer fails, the model is moved back to CPU and
        ``self.use_mlu`` is cleared, so that later inference also keeps its
        inputs on CPU instead of sending them to a device the weights are not on.

        Raises:
            RuntimeError: if even the CPU fallback fails.
        """
        if not self.use_mlu:
            model = model.cpu()
            print("模型加载到CPU设备")
            return model
        try:
            # Go through CPU first so the transfer starts from a known state.
            model = model.cpu().mlu()
        except Exception as e:
            print(f"模型加载失败: {str(e)}")
            self.use_mlu = False
            try:
                model = model.cpu()
            except Exception:
                raise RuntimeError(f"模型加载完全失败: {str(e)}") from e
            print("Fallback到CPU模式")
            return model
        print("模型成功加载到MLU设备")
        return model

    def _verify_model_device(self):
        """Sanity-check the first model parameter; only prints, never raises."""
        try:
            first_param = next(self.model.parameters())
            if self.use_mlu:
                # A trivial op on the parameter forces an MLU computation.
                _ = first_param + 0
                print("MLU模型验证成功")
            else:
                print("CPU模型验证成功")
        except StopIteration:
            print("警告: 模型没有可训练参数")
        except Exception as e:
            print(f"模型验证警告: {e}")

    def _run_benchmark(self, inputs, tag: str, iterations: int):
        """Time one forward pass ("T1"), then *iterations* more ("T2").

        Prints ``"<tag> T1 <seconds>"`` and ``"<tag> T2 <seconds>"`` and returns
        the outputs of the last pass (the first pass if *iterations* <= 0).
        NOTE(review): if MLU execution is asynchronous, these wall-clock
        timings may not include device completion — confirm before trusting them.
        """
        with torch.no_grad():
            ts = time.perf_counter()
            outputs = self.model(**inputs)
            print(f"{tag} T1", time.perf_counter() - ts, flush=True)
            ts = time.perf_counter()
            for _ in range(iterations):
                outputs = self.model(**inputs)
            print(f"{tag} T2", time.perf_counter() - ts, flush=True)
        return outputs

    def _predict_with_mlu(self, image) -> dict:
        """Classify *image* on the active device (MLU or CPU) and return top-1.

        Returns:
            dict with ``class_id``, ``class_name``, ``confidence``,
            ``device_used`` and ``processing_time`` (seconds, covering
            preprocessing plus all benchmark passes). On any failure the dict
            has ``class_id == -1`` and an ``error`` message instead of raising.
        """
        try:
            start_time = time.perf_counter()
            inputs = self.processor(images=image, return_tensors="pt")
            if self.use_mlu:
                inputs = {
                    key: value.mlu() if hasattr(value, "mlu") else value
                    for key, value in inputs.items()
                }
                outputs = self._run_benchmark(
                    inputs, "mlu370", self.mlu_benchmark_iterations
                )
            else:
                outputs = self._run_benchmark(
                    inputs, "cpu", self.cpu_benchmark_iterations
                )
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
            max_prob, max_idx = probs.max(dim=-1)
            # .item() assumes a single image (batch size 1).
            class_idx = max_idx.item()
            processing_time = round(time.perf_counter() - start_time, 6)
            return {
                "class_id": class_idx,
                "class_name": self.id2label.get(class_idx, f"class_{class_idx}"),
                "confidence": float(max_prob.item()),
                "device_used": "mlu" if self.use_mlu else "cpu",
                "processing_time": processing_time
            }
        except Exception as e:
            return {
                "class_id": -1,
                "class_name": "error",
                "confidence": 0.0,
                "device_used": "mlu" if self.use_mlu else "cpu",
                "processing_time": 0.0,
                "error": str(e)
            }

    def predict(self, image) -> dict:
        """Public prediction entry point; see ``_predict_with_mlu``."""
        return self._predict_with_mlu(image)
# Flask application object; the route handlers below register on it.
app = Flask(__name__)

# Load the classifier once at import time. Failure is reported but not raised,
# so the service still starts and the endpoints can report the degraded state
# (classifier stays None).
try:
    MODEL_PATH = os.environ.get("MODEL_PATH", "/model")
    print(f"从路径加载模型: {MODEL_PATH}")
    classifier = MLUImageClassifier(MODEL_PATH)
except Exception as e:
    print(f"服务初始化失败: {str(e)}")
    classifier = None
else:
    print("模型加载成功")
def _error_prediction(device_used, message):
    """Return the ``prediction`` payload shared by every error response.

    Args:
        device_used: "mlu", "cpu" or "unknown".
        message: Human-readable error text placed under ``"error"``.
    """
    return {
        "class_id": -1,
        "class_name": "error",
        "confidence": 0.0,
        "device_used": device_used,
        "processing_time": 0.0,
        "error": message,
    }


@app.route('/v1/private/s782b4996', methods=['POST'])
def predict():
    """Classify the uploaded ``image`` file on MLU/CPU and return JSON.

    Reads the ``image`` entry of ``request.files`` (a multipart form upload).
    Responds with ``{"status": "success", "prediction": {...}}`` on success.
    Otherwise ``status`` is ``"error"``; the HTTP code is 400 when no image was
    sent and 500 when the service is not initialised, the image cannot be
    decoded, or inference fails.
    """
    if classifier is None:
        return jsonify({
            "status": "error",
            "prediction": _error_prediction("unknown", "服务未初始化成功"),
        }), 500
    if 'image' not in request.files:
        return jsonify({
            "status": "error",
            "prediction": _error_prediction(
                "mlu" if classifier.use_mlu else "cpu", "请求中未包含图片"
            ),
        }), 400
    try:
        image_file = request.files['image']
        # NOTE(review): the upload is read fully into memory and no size limit
        # is configured on this app; consider Flask's MAX_CONTENT_LENGTH.
        image = Image.open(BytesIO(image_file.read())).convert("RGB")
        result = classifier.predict(image)
        # classifier.predict reports inference failures through an "error"
        # key rather than raising.
        if 'error' in result:
            return jsonify({"status": "error", "prediction": result}), 500
        return jsonify({"status": "success", "prediction": result})
    except Exception as e:
        return jsonify({
            "status": "error",
            "prediction": _error_prediction(
                "mlu" if classifier and classifier.use_mlu else "cpu",
                f"处理图片失败: {str(e)}",
            ),
        }), 500
@app.route('/health', methods=['GET'])
def health_check():
    """Report service health, MLU availability and model status as JSON.

    ``status`` is "healthy" when the model loaded and "degraded" otherwise.
    Errors raised while probing the MLU are reported under
    ``mlu_info["error"]`` instead of failing the request.
    """
    mlu_available = False
    mlu_info = {}
    if torch_mlu is not None and hasattr(ct, 'is_mlu_available'):
        try:
            mlu_available = ct.is_mlu_available()
            # Probe once: ct.device_count() allocates a tensor on each device it
            # checks and prints a line per device, so calling it twice per
            # health check doubled that work.
            device_count = ct.device_count()
            mlu_info = {
                "device_count": device_count,
                "devices": [ct.get_device_name(i) for i in range(device_count)],
            }
        except Exception as e:
            mlu_info["error"] = str(e)
    return jsonify({
        "status": "healthy" if classifier is not None else "degraded",
        "mlu_available": mlu_available,
        "mlu_info": mlu_info,
        "model_loaded": classifier is not None,
        "using_mlu": classifier.use_mlu if classifier else False,
        "timestamp": time.time()
    })
@app.route('/test', methods=['GET'])
def test_mlu():
    """Run a tiny tensor add on the MLU and report whether it worked.

    Returns 200 with the result's shape on success, or 500 with
    ``mlu_working: False`` when torch_mlu is missing or the operation fails.
    """
    try:
        if torch_mlu is not None:
            base = torch.randn(3, 3).mlu()
            # Copy back to host so the device result is actually materialised.
            on_host = (base + base).cpu()
            return jsonify({
                "status": "success",
                "message": "MLU测试通过",
                "result_shape": str(on_host.shape),
                "mlu_working": True
            })
        reason = "torch_mlu模块未找到"
    except Exception as e:
        reason = f"MLU测试失败: {str(e)}"
    return jsonify({
        "status": "error",
        "message": reason,
        "mlu_working": False
    }), 500
@app.route('/info', methods=['GET'])
def device_info():
    """Return PyTorch version, torch_mlu/MLU availability and model status."""
    can_count_devices = bool(torch_mlu) and hasattr(ct, 'device_count')
    payload = dict(
        pytorch_version=torch.__version__,
        torch_mlu_available=torch_mlu is not None,
        mlu_devices_count=ct.device_count() if can_count_devices else 0,
        model_loaded=classifier is not None,
        using_mlu=False if classifier is None else classifier.use_mlu,
        system_time=time.time(),
    )
    return jsonify(payload)
if __name__ == "__main__":
    # Start the HTTP service on Flask's built-in development server.
    # The PORT environment variable overrides the listening port, mirroring
    # how MODEL_PATH is configured; the default of 80 keeps the original
    # behaviour (ports below 1024 usually require elevated privileges).
    print("启动MLU图像分类服务...")
    app.run(host='0.0.0.0', port=int(os.environ.get("PORT", "80")), debug=False)