# CUDA -> NPU (Ascend) compatibility shim and CPU-mapped torch.load patch.
from functools import wraps

import torch
def enable_cuda_to_npu_shim():
    """Monkey-patch common ``torch.cuda`` entry points onto ``torch.npu``.

    Lets CUDA-oriented code run on Ascend NPUs by remapping a small,
    well-known subset of the ``torch.cuda`` API. Deliberately maps only
    the common functions — do not be greedy. Device strings should
    consistently use ``npu``; business code is still advised to write
    ``model.to("npu:0")`` explicitly.

    Raises:
        ImportError: if ``torch_npu`` is not installed.
    """
    print("enable_cuda_to_npu_shim")
    import torch_npu  # noqa: F401 — importing registers the "npu" backend
    torch.cuda.is_available = torch.npu.is_available
    torch.cuda.device_count = torch.npu.device_count
    torch.cuda.current_device = torch.npu.current_device
    torch.cuda.set_device = torch.npu.set_device
    torch.cuda.synchronize = torch.npu.synchronize
    # empty_cache exists only on some torch_npu versions; narrow the guard
    # to AttributeError so unrelated failures are not silently swallowed.
    try:
        torch.cuda.empty_cache = torch.npu.empty_cache
    except AttributeError:
        pass
# Enable the shim only when an NPU is present and CUDA is not, so
# native-CUDA machines keep the real torch.cuda API untouched.
try:
    import torch_npu  # noqa: F401 — registers the "npu" device backend
    if torch.npu.is_available() and not torch.cuda.is_available():
        enable_cuda_to_npu_shim()
except Exception:
    # NOTE: was a bare ``except:``, which would also swallow
    # KeyboardInterrupt/SystemExit; Exception keeps the best-effort
    # fallback without masking interpreter-exit signals.
    print("no npu. use native cuda")
# 1) Optional: when checkpoints come from pytorch-lightning, allow-list its
#    ModelCheckpoint class for weights_only loading (trusted sources only).
try:
    from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
    from torch.serialization import add_safe_globals

    add_safe_globals([ModelCheckpoint])
except Exception:
    # Best effort: lightning or a new-enough torch may simply be absent.
    pass
# 2) Route torch.load to CPU by default, avoiding CUDA deserialization
#    errors for checkpoints saved on GPU machines.
_orig_load = torch.load

@wraps(_orig_load)
def _load_map_to_cpu(*args, **kwargs):
    """``torch.load`` wrapper that defaults ``map_location`` to CPU.

    Caller-supplied keyword arguments always win — ``setdefault`` only
    fills in keys the caller omitted.
    """
    kwargs.setdefault("map_location", "cpu")
    # SECURITY NOTE: weights_only=False permits arbitrary pickle code
    # execution during load — only appropriate for trusted checkpoints.
    kwargs.setdefault("weights_only", False)
    return _orig_load(*args, **kwargs)

torch.load = _load_map_to_cpu