import torch
from functools import wraps


def enable_cuda_to_npu_shim():
    """Monkey-patch common ``torch.cuda`` entry points to their Ascend NPU
    equivalents so CUDA-oriented code can run on NPU hardware.

    Deliberately maps only the common query/control functions — do not
    over-map. Device strings should still be written as ``npu``; call
    sites are encouraged to be explicit, e.g. ``model.to("npu:0")``.
    """
    print("enable_cuda_to_npu_shim")
    import torch_npu  # noqa: F401  # importing registers the "npu" backend on torch

    # Map only the common functions; keep the patched surface small.
    torch.cuda.is_available = torch.npu.is_available
    torch.cuda.device_count = torch.npu.device_count
    torch.cuda.current_device = torch.npu.current_device
    torch.cuda.set_device = torch.npu.set_device
    torch.cuda.synchronize = torch.npu.synchronize

    # empty_cache is only present on some torch_npu versions — probe for it
    # instead of assigning inside a try/except (the assignment itself could
    # not raise; only the attribute lookup could).
    empty_cache = getattr(torch.npu, "empty_cache", None)
    if empty_cache is not None:
        torch.cuda.empty_cache = empty_cache


# Enable the shim only when an NPU stack is importable, reports a device,
# and native CUDA is absent. Narrowed from a bare `except:` so that
# SystemExit/KeyboardInterrupt are no longer swallowed.
try:
    import torch_npu  # noqa: F401
    if torch.npu.is_available() and not torch.cuda.is_available():
        enable_cuda_to_npu_shim()
except Exception:
    print("no npu. use native cuda")

# 1) Optional: if your weights come from a Lightning checkpoint, allow-list
#    its class for deserialization (only when the source is trusted).
try:
    from torch.serialization import add_safe_globals
    from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

    add_safe_globals([ModelCheckpoint])
except Exception:
    pass  # best-effort: older torch / missing pytorch_lightning is fine

# 2) Default torch.load to CPU mapping to avoid CUDA deserialization errors
#    when a checkpoint saved on GPU is loaded on a machine without one.
_orig_load = torch.load


@wraps(_orig_load)  # preserve torch.load's name/doc on the patched function
def _load_map_to_cpu(*args, **kwargs):
    """``torch.load`` wrapper: default ``map_location`` to ``"cpu"``.

    Caller-supplied keyword arguments always win (``setdefault``).
    NOTE(security): ``weights_only=False`` permits full pickle execution
    during load — only load checkpoints from trusted sources.
    """
    kwargs.setdefault("map_location", "cpu")
    kwargs.setdefault("weights_only", False)
    return _orig_load(*args, **kwargs)


torch.load = _load_map_to_cpu