# 2025-09-05 12:03:13 +08:00
|
|
|
|
import torch
|
|
|
|
|
|
from functools import wraps
|
|
|
|
|
|
|
|
|
|
|
|
def enable_cuda_to_npu_shim():
    """Redirect a small set of ``torch.cuda`` entry points to ``torch.npu``.

    Intended to be called when an Ascend NPU is present but no CUDA device
    is, so that code written against ``torch.cuda`` keeps working. Only the
    most common functions are mapped on purpose — do not try to cover the
    whole CUDA surface.
    """
    print("enable_cuda_to_npu_shim")

    import torch_npu  # noqa: F401 — importing registers the "npu" backend

    # Map only the common entry points; deliberately not exhaustive.
    for _name in (
        "is_available",
        "device_count",
        "current_device",
        "set_device",
        "synchronize",
    ):
        setattr(torch.cuda, _name, getattr(torch.npu, _name))

    try:
        # Some torch_npu versions expose an empty-cache hook; best effort.
        torch.cuda.empty_cache = torch.npu.empty_cache
    except Exception:
        pass

    # Device strings should uniformly use "npu".
    # Application code is still encouraged to write model.to("npu:0") explicitly.
|
|
|
|
|
|
# Activate the CUDA->NPU shim only when an Ascend NPU is available and no
# real CUDA device is — otherwise torch.cuda is left untouched.
try:
    import torch_npu  # noqa: F401 — registers the "npu" device backend

    if torch.npu.is_available() and not torch.cuda.is_available():
        enable_cuda_to_npu_shim()
except Exception:
    # Fix: was a bare `except:`, which also swallows SystemExit and
    # KeyboardInterrupt. torch_npu missing or NPU probing failed; fall
    # back to stock CUDA behavior.
    print("no npu. use native cuda")
|
|
|
|
|
|
|
|
|
|
# 1) Optional: if your weights come from a Lightning ckpt, allow-list its
#    class for safe deserialization (only when the source is trusted).
try:
    from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
    from torch.serialization import add_safe_globals

    add_safe_globals([ModelCheckpoint])
except Exception:
    # pytorch_lightning or add_safe_globals unavailable — nothing to allow-list.
    pass
|
|
|
|
|
|
|
|
|
|
|
|
# 2) Route torch.load to CPU by default, avoiding CUDA deserialization
#    errors when the checkpoint was saved on a CUDA machine.
_orig_load = torch.load


@wraps(_orig_load)  # fix: `wraps` was imported but unused; preserve torch.load metadata
def _load_map_to_cpu(*args, **kwargs):
    """``torch.load`` wrapper that defaults ``map_location`` to ``"cpu"``.

    Callers can still override either keyword explicitly; ``setdefault``
    only fills values that were not passed.

    SECURITY NOTE: ``weights_only=False`` re-enables full pickle
    deserialization, which can execute arbitrary code — only load
    checkpoints from trusted sources.
    """
    kwargs.setdefault("map_location", "cpu")
    kwargs.setdefault("weights_only", False)
    return _orig_load(*args, **kwargs)


torch.load = _load_map_to_cpu
|
|
|
|
|
|
|
|
|
|
|
|
|