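"""Flask inference service for Qwen2.5-VL on Cambricon MLU devices.

The model is loaded with device_map="auto" so accelerate can shard it across
the available MLU cards; when torch_mlu is missing, the service falls back to
CPU. Exposes /predict, /health, /info, and /test endpoints.
"""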
import torch
import time
import os
from PIL import Image
from io import BytesIO
from flask import Flask, request, jsonify

# Pull in Cambricon MLU support
try:
    import torch_mlu
    print(f"Successfully imported torch_mlu, version: {getattr(torch_mlu, '__version__', 'unknown')}")

    def check_mlu_available():
        try:
            torch.randn(2, 2).mlu()  # smoke test: allocate a small tensor on the MLU
            return True
        except Exception as e:
            print(f"MLU unavailable: {e}")
            return False

    MLU_AVAILABLE = check_mlu_available()
    print(f"MLU device available: {MLU_AVAILABLE}")
except ImportError as e:
    torch_mlu = None
    MLU_AVAILABLE = False
    print(f"Warning: torch_mlu module not found: {e}")

if MLU_AVAILABLE:
    # torch.mlu only exists after torch_mlu has been imported successfully
    print(f"MLU count: {torch.mlu.device_count()}")

# Limit the thread count
os.environ["OMP_NUM_THREADS"] = "4"
torch.set_num_threads(4)

# Import the Qwen-VL-specific components
try:
    from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
    from qwen_vl_utils import process_vision_info
except ImportError:
    raise ImportError("Missing dependencies, install with: pip install transformers qwen-vl-utils")

# Global configuration
MODEL_PATH = os.environ.get("MODEL_PATH", "/models")
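
# Example launch with a custom model directory (the path and script name below
# are hypothetical):
#   MODEL_PATH=/models/Qwen2.5-VL-7B-Instruct python app.py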


class QwenVLMLUClassifier:
    def __init__(self, model_path: str):
        self.use_mlu = MLU_AVAILABLE
        print(f"Initializing model, device: {'MLU' if self.use_mlu else 'CPU'}")

        # Load the processor
        print(f"Loading processor from {model_path}...")
        self.processor = AutoProcessor.from_pretrained(model_path)

        print("Loading model...")

        # Key point: device_map="auto" lets accelerate shard the model across cards
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto",  # automatically distribute across all available MLU cards
            # Note: do not request flash_attention_2 (not supported on MLU)
        )

        # Never move the model manually; device_map has already placed it.
        # self.model = self.model.mlu()  # would break the device_map sharding

        self.model = self.model.eval()
        print("Model loaded; layers placed according to device_map")

        # Print the per-layer placement for debugging. from_pretrained fills in
        # hf_device_map whenever a device_map is used.
        print("Per-layer device placement:")
        for name, device in getattr(self.model, "hf_device_map", {}).items():
            print(f"  {name}: {device}")

    def predict(self, image: Image.Image, prompt: str = "Describe this image.") -> dict:
        start_time = time.perf_counter()

        try:
            # Build the chat messages
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": prompt},
                    ],
                }
            ]

            # Apply the chat template
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            # Extract the vision inputs
            image_inputs, video_inputs = process_vision_info(messages)

            # Build the model inputs
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                # truncation=True,  # optionally enable truncation
                # max_length=2048,  # optionally cap the sequence length
                max_pixels=384 * 384,  # bound the image resolution (important!)
                return_tensors="pt",
            )

            # Cast grid-shape fields such as image_grid_thw to int32 (if present)
            for key in ['image_grid_thw', 'video_grid_thw']:
                if key in inputs and inputs[key].dtype == torch.long:
                    # Make sure the values actually fit into int32 first
                    val = inputs[key]
                    if val.max() >= 2147483647 or val.min() < -2147483648:
                        print(f"Warning: {key} out of int32 range, clamping...")
                        val = val.clamp(-2147483648, 2147483647)
                    inputs[key] = val.to(torch.int32)

            # Move inputs to the model's entry device; with device_map="auto",
            # accelerate's hooks handle cross-card transfers during generate.
            # Do not pin inputs to a specific card with .to("mlu:x").
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            with torch.no_grad():
                ts = time.time()
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=128,
                    # synced_gpus=True,  # optional: synchronized multi-card generation
                )
                print(f"Generation took: {time.time() - ts:.3f}s")

            # Decode the output, dropping the prompt tokens
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
            ]
            output_texts = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )
            response = output_texts[0].strip()

            processing_time = round(time.perf_counter() - start_time, 4)

            return {
                "response": response,
                "device_used": "mlu" if self.use_mlu else "cpu",
                "processing_time": processing_time,
                "status": "success"
            }

        except Exception as e:
            return {
                "response": "",
                "error": str(e),
                "device_used": "mlu" if self.use_mlu else "cpu",
                "processing_time": 0.0,
                "status": "error"
            }
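
# Standalone usage sketch (bypasses Flask; "example.jpg" is a placeholder):
#   clf = QwenVLMLUClassifier(MODEL_PATH)
#   out = clf.predict(Image.open("example.jpg").convert("RGB"), "Describe this image.")
#   print(out["response"])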

# Initialize Flask
app = Flask(__name__)
classifier = None


@app.before_request
def ensure_model_loaded():
    if request.endpoint == 'predict' and classifier is None:
        return jsonify({"status": "error", "message": "Model not loaded"}), 500


@app.route('/predict', methods=['POST'])
def predict():
    if 'image' not in request.files:
        return jsonify({
            "status": "error",
            "message": "No image included in the request"
        }), 400

    try:
        image_file = request.files['image']
        image = Image.open(BytesIO(image_file.read())).convert("RGB")
        prompt = request.form.get('prompt', 'Describe this image.')
        result = classifier.predict(image, prompt)

        if result["status"] == "success":
            return jsonify({
                "status": "success",
                "response": result["response"],
                "device_used": result["device_used"],
                "processing_time": result["processing_time"]
            })
        else:
            return jsonify({
                "status": "error",
                "error": result["error"],
                "device_used": result["device_used"]
            }), 500

    except Exception as e:
        return jsonify({
            "status": "error",
            "error": f"Processing failed: {str(e)}"
        }), 500


@app.route('/health', methods=['GET'])
def health_check():
    return jsonify({
        "status": "healthy" if classifier is not None else "unhealthy",
        "model_loaded": classifier is not None,
        "using_mlu": classifier.use_mlu if classifier else False,
        "mlu_available": MLU_AVAILABLE,
        "timestamp": time.time()
    })


@app.route('/info', methods=['GET'])
def device_info():
    return jsonify({
        "model": MODEL_PATH,
        "torch_version": torch.__version__,
        "torch_mlu_version": getattr(torch_mlu, '__version__', None) if torch_mlu else None,
        "mlu_available": MLU_AVAILABLE,
        "using_device": "mlu" if classifier and classifier.use_mlu else "cpu",
        "model_loaded": classifier is not None,
        "timestamp": time.time()
    })


@app.route('/test', methods=['GET'])
def test_mlu():
    if not MLU_AVAILABLE:
        return jsonify({"status": "error", "message": "MLU unavailable"}), 500
    try:
        x = torch.randn(2, 2).mlu()
        y = x + x
        return jsonify({"status": "success", "result": y.cpu().tolist()})
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500


if __name__ == "__main__":
    print(f"Loading model: {MODEL_PATH}")
    try:
        classifier = QwenVLMLUClassifier(MODEL_PATH)
        print("Model loaded successfully, starting the Flask service...")
        app.run(host='0.0.0.0', port=80, debug=False)
    except Exception as e:
        print(f"Model loading failed: {e}")
        exit(1)
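
# Example client (a sketch: assumes the server is reachable at localhost:80 and
# that the third-party `requests` package is installed; neither is part of this
# service itself):
#
#   import requests
#   with open("example.jpg", "rb") as f:
#       resp = requests.post(
#           "http://localhost:80/predict",
#           files={"image": f},
#           data={"prompt": "Describe this image."},
#       )
#   print(resp.json())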