"""Flask inference service for Qwen2.5-VL, targeting Cambricon MLU devices.

The module probes for ``torch_mlu`` at import time, loads the model once at
start-up (when run as a script) and exposes four HTTP endpoints:

* ``POST /predict`` - run image + prompt generation.
* ``GET /health``   - report whether the model is loaded.
* ``GET /info``     - report library versions and the device in use.
* ``GET /test``     - run a tiny tensor operation on the MLU.
"""
import os
import time
from io import BytesIO

import torch
from flask import Flask, request, jsonify
from PIL import Image


def check_mlu_available():
    """Return True if a small tensor can be moved to an MLU device.

    Any failure (missing backend, no device, driver error) is printed and
    reported as ``False`` instead of being raised.
    """
    try:
        torch.randn(2, 2).mlu()
        return True
    except Exception as e:
        print(f"MLU不可用: {e}")
        return False


# Cambricon MLU support is optional: fall back to CPU when it is missing.
try:
    import torch_mlu
except ImportError as e:
    torch_mlu = None
    MLU_AVAILABLE = False
    print(f"警告: 未找到 torch_mlu 模块: {e}")
else:
    print(f"成功导入 torch_mlu,版本: {getattr(torch_mlu, '__version__', 'unknown')}")
    MLU_AVAILABLE = check_mlu_available()
    print(f"MLU设备可用: {MLU_AVAILABLE}")

# Only touch torch.mlu after the MLU probe succeeded, so that a CPU-only
# host does not fail at import time.
if MLU_AVAILABLE:
    print(f"MLU count: {torch.mlu.device_count()}")

# Limit the number of CPU threads.
os.environ["OMP_NUM_THREADS"] = "4"
torch.set_num_threads(4)

# Qwen-VL specific components.
try:
    from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
    from qwen_vl_utils import process_vision_info
except ImportError as import_err:
    raise ImportError("请安装依赖: pip install transformers qwen-vl-utils") from import_err

# Model location, overridable through the MODEL_PATH environment variable.
MODEL_PATH = os.environ.get("MODEL_PATH", "/models")


class QwenVLMLUClassifier:
    """Wrap a Qwen2.5-VL model and processor for single-image prompting."""

    def __init__(self, model_path: str):
        """Load the processor and the model from *model_path*.

        The model is loaded in float16 with ``device_map="auto"`` and put in
        eval mode; it is deliberately never moved with ``.mlu()`` afterwards,
        because placement is left entirely to ``device_map``.

        Args:
            model_path: Directory or identifier passed to ``from_pretrained``.
        """
        self.use_mlu = MLU_AVAILABLE
        print(f"初始化模型,使用设备: {'MLU' if self.use_mlu else 'CPU'}")

        print(f"从 {model_path} 加载处理器...")
        self.processor = AutoProcessor.from_pretrained(model_path)

        print("加载模型中...")
        # device_map="auto" delegates device placement to the loader.
        # flash_attention_2 is intentionally not requested here (the original
        # author notes it is not supported on MLU).
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto",
        )
        self.model = self.model.eval()
        print("模型加载完成,已根据 device_map 分配到设备")

        self._print_device_layout()

    def _print_device_layout(self):
        """Print the device of every leaf module (debugging aid)."""
        print("模型各层设备分布:")
        for name, module in self.model.named_modules():
            # Report leaf modules only; containers are skipped rather than
            # terminating the walk.
            if next(module.children(), None) is not None:
                continue
            weight = getattr(module, "weight", None)
            hook = getattr(module, "_hf_hook", None)
            if weight is not None:
                print(f"{name}: {weight.device}")
            elif hook is not None and hasattr(hook, "execution_device"):
                print(f"{name}: {hook.execution_device}")

    def predict(self, image: Image.Image, prompt: str = "Describe this image.") -> dict:
        """Generate a text response for *image* and *prompt*.

        Args:
            image: The input picture.
            prompt: The instruction sent to the model together with the image.

        Returns:
            On success a dict with keys ``response``, ``device_used``,
            ``processing_time`` and ``status`` (``"success"``). On failure a
            dict with ``response`` (empty), ``error``, ``device_used``,
            ``processing_time`` (``0.0``) and ``status`` (``"error"``).
            Exceptions are never propagated to the caller.
        """
        start_time = time.perf_counter()
        # Report the device that was actually selected, also on success.
        device_label = "mlu" if self.use_mlu else "cpu"
        try:
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": prompt},
                    ],
                }
            ]

            # Render the chat template to plain text (not yet tokenized).
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            image_inputs, video_inputs = process_vision_info(messages)

            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                max_pixels=384 * 384,  # caps the image size fed to the model
                return_tensors="pt",
            )

            # Down-cast the grid descriptors from int64 to int32 when present,
            # clamping first if a value would not fit.
            int32_info = torch.iinfo(torch.int32)
            for key in ("image_grid_thw", "video_grid_thw"):
                if key in inputs and inputs[key].dtype == torch.long:
                    val = inputs[key]
                    if val.max() > int32_info.max or val.min() < int32_info.min:
                        print(f"Warning: {key} out of int32 range, clamping...")
                        val = val.clamp(int32_info.min, int32_info.max)
                    inputs[key] = val.to(torch.int32)

            # Move every input tensor to the device reported by the model.
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            with torch.no_grad():
                generation_start = time.perf_counter()
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=128,
                )
                print(f"生成耗时: {time.perf_counter() - generation_start:.3f}s")

            # Drop the prompt tokens so only newly generated ones are decoded.
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
            ]
            output_texts = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            response = output_texts[0].strip()

            processing_time = round(time.perf_counter() - start_time, 4)
            return {
                "response": response,
                "device_used": device_label,
                "processing_time": processing_time,
                "status": "success",
            }
        except Exception as e:
            return {
                "response": "",
                "error": str(e),
                "device_used": device_label,
                "processing_time": 0.0,
                "status": "error",
            }


# Flask application; `classifier` is populated by the script entry point.
app = Flask(__name__)
classifier = None


@app.before_request
def ensure_model_loaded():
    """Reject /predict requests with HTTP 500 while no model is loaded."""
    if request.endpoint == 'predict' and classifier is None:
        return jsonify({"status": "error", "message": "模型未加载"}), 500


@app.route('/predict', methods=['POST'])
def predict():
    """Run inference on the uploaded ``image`` file and optional ``prompt``.

    Returns HTTP 400 when no image is attached, HTTP 500 when decoding or
    inference fails, and the generation result as JSON otherwise.
    """
    if 'image' not in request.files:
        return jsonify({
            "status": "error",
            "message": "请求中未包含图片"
        }), 400

    try:
        image_file = request.files['image']
        image = Image.open(BytesIO(image_file.read())).convert("RGB")
        prompt = request.form.get('prompt', 'Describe this image.')

        result = classifier.predict(image, prompt)

        if result["status"] == "success":
            return jsonify({
                "status": "success",
                "response": result["response"],
                "device_used": result["device_used"],
                "processing_time": result["processing_time"]
            })
        return jsonify({
            "status": "error",
            "error": result["error"],
            "device_used": result["device_used"]
        }), 500
    except Exception as e:
        return jsonify({
            "status": "error",
            "error": f"处理失败: {str(e)}"
        }), 500


@app.route('/health', methods=['GET'])
def health_check():
    """Report whether the model is loaded and whether MLU is in use."""
    model_loaded = classifier is not None
    return jsonify({
        "status": "healthy" if model_loaded else "unhealthy",
        "model_loaded": model_loaded,
        "using_mlu": classifier.use_mlu if classifier else False,
        "mlu_available": MLU_AVAILABLE,
        "timestamp": time.time()
    })


@app.route('/info', methods=['GET'])
def device_info():
    """Report the model path, library versions and the device in use."""
    return jsonify({
        "model": MODEL_PATH,
        "torch_version": torch.__version__,
        "torch_mlu_version": getattr(torch_mlu, '__version__', None) if torch_mlu else None,
        "mlu_available": MLU_AVAILABLE,
        "using_device": "mlu" if classifier and classifier.use_mlu else "cpu",
        "model_loaded": classifier is not None,
        "timestamp": time.time()
    })


@app.route('/test', methods=['GET'])
def test_mlu():
    """Run a tiny addition on the MLU and return the result as a list."""
    if not MLU_AVAILABLE:
        return jsonify({"status": "error", "message": "MLU不可用"}), 500
    try:
        x = torch.randn(2, 2).mlu()
        y = x + x
        return jsonify({"status": "success", "result": y.cpu().tolist()})
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500


if __name__ == "__main__":
    print(f"正在加载模型: {MODEL_PATH}")
    # Only model loading is guarded here, so that a server start-up failure
    # is not misreported as a model-loading failure.
    try:
        classifier = QwenVLMLUClassifier(MODEL_PATH)
    except Exception as e:
        print(f" 模型加载失败: {e}")
        # SystemExit does not depend on the `exit` helper injected by `site`.
        raise SystemExit(1)

    print("模型加载成功,启动 Flask 服务...")
    app.run(host='0.0.0.0', port=80, debug=False)