Files
enginex-mlu370-vl/Qwen2.5-VL-32B-Instruct-test.py
2025-10-16 18:33:26 +08:00

260 lines
8.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import time
import os
from PIL import Image
from io import BytesIO
from flask import Flask, request, jsonify
# Cambricon MLU support: importing torch_mlu registers the "mlu" device
# type with torch. Everything below degrades gracefully to CPU.
try:
    import torch_mlu
    print(f"成功导入 torch_mlu版本: {getattr(torch_mlu, '__version__', 'unknown')}")

    def check_mlu_available():
        """Probe the MLU runtime by allocating one tiny tensor on-device.

        Returns True when the allocation succeeds, False otherwise. The broad
        except is deliberate: any runtime/driver failure just means "no MLU".
        """
        try:
            torch.randn(2, 2).mlu()
            return True
        except Exception as e:
            print(f"MLU不可用: {e}")
            return False

    MLU_AVAILABLE = check_mlu_available()
    print(f"MLU设备可用: {MLU_AVAILABLE}")
except ImportError as e:
    torch_mlu = None
    MLU_AVAILABLE = False
    print(f"警告: 未找到 torch_mlu 模块: {e}")

# BUG FIX: torch.mlu only exists after torch_mlu imports successfully; the
# previous unguarded call crashed the whole script with AttributeError on
# machines without the MLU stack.
if MLU_AVAILABLE:
    print(f"MLU count: {torch.mlu.device_count()}")

# Cap host-side threading so CPU preprocessing does not oversubscribe cores.
os.environ["OMP_NUM_THREADS"] = "4"
torch.set_num_threads(4)
# Qwen-VL specific components: model class, processor, vision preprocessing.
try:
    from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
    from qwen_vl_utils import process_vision_info
except ImportError as e:
    # Chain the original error so the name of the missing package is not lost.
    raise ImportError("请安装依赖: pip install transformers qwen-vl-utils") from e

# Model directory; overridable through the MODEL_PATH environment variable.
MODEL_PATH = os.environ.get("MODEL_PATH", "/models")
class QwenVLMLUClassifier:
    """Image+text inference wrapper around Qwen2.5-VL for Cambricon MLU.

    The checkpoint is sharded across all visible MLU cards by accelerate via
    device_map="auto". Inputs are moved to the model's entry device; the
    accelerate hooks forward activations between shards during generate().
    """

    def __init__(self, model_path: str):
        # Whether the MLU runtime probe at import time succeeded.
        self.use_mlu = MLU_AVAILABLE
        print(f"初始化模型,使用设备: {'MLU' if self.use_mlu else 'CPU'}")

        # Load the processor (tokenizer + image preprocessor).
        print(f"{model_path} 加载处理器...")
        self.processor = AutoProcessor.from_pretrained(model_path)

        print("加载模型中...")
        # device_map="auto" lets accelerate place layers across every
        # available card. Do NOT call .mlu() on the model afterwards — that
        # would undo the per-layer placement. flash_attention_2 is not
        # supported on MLU, so the default attention implementation is kept.
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto",
        )
        self.model = self.model.eval()
        print("模型加载完成,已根据 device_map 分配到设备")

        # BUG FIX: the old debug loop `break`-ed at the first leaf module, so
        # the device layout was effectively never printed. accelerate records
        # the final placement in hf_device_map — print that instead.
        print("模型各层设备分布:")
        for layer_name, device in getattr(self.model, "hf_device_map", {}).items():
            print(f"{layer_name}: {device}")

    def predict(self, image: Image.Image, prompt: str = "Describe this image.") -> dict:
        """Run one image+prompt generation round trip.

        Args:
            image: PIL image to describe.
            prompt: user text accompanying the image.

        Returns:
            dict with keys response / device_used / processing_time / status,
            plus "error" when status == "error". Never raises: all failures
            are reported through the error payload so the HTTP layer can
            return a clean 500.
        """
        start_time = time.perf_counter()
        try:
            # One user turn containing an image part and a text part.
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": prompt},
                    ],
                }
            ]
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, video_inputs = process_vision_info(messages)

            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                # Cap image resolution: bounds the visual token count and
                # therefore MLU memory use.
                max_pixels=384 * 384,
                return_tensors="pt",
            )

            # MLU kernels expect int32 grid indices; downcast the int64 grid
            # tensors, clamping first so the cast cannot wrap around.
            for key in ("image_grid_thw", "video_grid_thw"):
                if key in inputs and inputs[key].dtype == torch.long:
                    val = inputs[key]
                    if val.max() >= 2147483647 or val.min() < -2147483648:
                        print(f"Warning: {key} out of int32 range, clamping...")
                        val = val.clamp(-2147483648, 2147483647)
                    inputs[key] = val.to(torch.int32)

            # Move inputs to the model's entry device; with device_map the
            # accelerate hooks route activations to the right shard.
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            with torch.no_grad():
                ts = time.time()
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=128,
                )
            print(f"生成耗时: {time.time() - ts:.3f}s")

            # Strip the prompt tokens; keep only the newly generated tail.
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
            ]
            output_texts = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            response = output_texts[0].strip()

            processing_time = round(time.perf_counter() - start_time, 4)
            return {
                "response": response,
                # BUG FIX: was hard-coded "mlu" even when the model had fallen
                # back to CPU; now consistent with the error branch below.
                "device_used": "mlu" if self.use_mlu else "cpu",
                "processing_time": processing_time,
                "status": "success",
            }
        except Exception as e:
            # Deliberate broad catch: predict() is an API boundary and must
            # always return a structured error instead of raising.
            return {
                "response": "",
                "error": str(e),
                "device_used": "mlu" if self.use_mlu else "cpu",
                "processing_time": 0.0,
                "status": "error",
            }
# Flask application plus the classifier slot populated at startup in __main__.
app = Flask(__name__)
classifier = None


@app.before_request
def ensure_model_loaded():
    """Short-circuit /predict requests that arrive before the model is ready."""
    if classifier is None and request.endpoint == 'predict':
        return jsonify({"status": "error", "message": "模型未加载"}), 500
@app.route('/predict', methods=['POST'])
def predict():
    """POST endpoint: multipart 'image' file plus optional 'prompt' form
    field; responds with the model's generated description as JSON."""
    uploaded = request.files.get('image')
    if uploaded is None:
        return jsonify({
            "status": "error",
            "message": "请求中未包含图片"
        }), 400
    try:
        img = Image.open(BytesIO(uploaded.read())).convert("RGB")
        user_prompt = request.form.get('prompt', 'Describe this image.')
        result = classifier.predict(img, user_prompt)
        if result["status"] != "success":
            return jsonify({
                "status": "error",
                "error": result["error"],
                "device_used": result["device_used"]
            }), 500
        return jsonify({
            "status": "success",
            "response": result["response"],
            "device_used": result["device_used"],
            "processing_time": result["processing_time"]
        })
    except Exception as e:
        return jsonify({
            "status": "error",
            "error": f"处理失败: {str(e)}"
        }), 500
@app.route('/health', methods=['GET'])
def health_check():
    """Readiness probe: healthy once the classifier has been constructed."""
    loaded = classifier is not None
    return jsonify({
        "status": "healthy" if loaded else "unhealthy",
        "model_loaded": loaded,
        "using_mlu": classifier.use_mlu if classifier else False,
        "mlu_available": MLU_AVAILABLE,
        "timestamp": time.time()
    })
@app.route('/info', methods=['GET'])
def device_info():
    """Deployment info: model path, library versions, device availability."""
    mlu_version = getattr(torch_mlu, '__version__', None) if torch_mlu else None
    active_device = "mlu" if classifier and classifier.use_mlu else "cpu"
    return jsonify({
        "model": MODEL_PATH,
        "torch_version": torch.__version__,
        "torch_mlu_version": mlu_version,
        "mlu_available": MLU_AVAILABLE,
        "using_device": active_device,
        "model_loaded": classifier is not None,
        "timestamp": time.time()
    })
@app.route('/test', methods=['GET'])
def test_mlu():
    """Smoke test: run one tiny tensor addition on the MLU, echo the result."""
    if not MLU_AVAILABLE:
        return jsonify({"status": "error", "message": "MLU不可用"}), 500
    try:
        probe = torch.randn(2, 2).mlu()
        doubled = probe + probe
        return jsonify({"status": "success", "result": doubled.cpu().tolist()})
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
if __name__ == "__main__":
    print(f"正在加载模型: {MODEL_PATH}")
    try:
        classifier = QwenVLMLUClassifier(MODEL_PATH)
    except Exception as e:
        print(f" 模型加载失败: {e}")
        # exit() is a site-module convenience that is absent under `python -S`
        # or in frozen builds; SystemExit is always available.
        raise SystemExit(1)
    # BUG FIX: app.run() used to sit inside the try above, so any crash of
    # the server itself was misreported as "模型加载失败" and swallowed.
    print("模型加载成功,启动 Flask 服务...")
    # NOTE(review): port 80 requires elevated privileges in most deployments
    # — confirm the container runs with the needed capability.
    app.run(host='0.0.0.0', port=80, debug=False)