Files
2025-09-08 16:32:50 +08:00

44 lines
1.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
from functools import wraps
def enable_cuda_to_npu_shim():
    """Redirect a small set of ``torch.cuda`` functions to their ``torch.npu`` equivalents.

    Importing ``torch_npu`` registers the ``torch.npu`` backend. The selected
    ``torch.cuda`` attributes are then rebound, process-wide, to the NPU
    functions, so code that queries/selects devices via ``torch.cuda`` talks
    to the NPU instead.

    NOTE: this function does not remap device strings. Application code should
    still name the NPU explicitly, e.g. ``model.to("npu:0")``.

    Raises:
        ImportError: if ``torch_npu`` is not installed.
    """
    print("enable_cuda_to_npu_shim")
    import torch_npu  # noqa: F401 -- imported for its side effect of registering torch.npu

    # Map only the commonly used entry points; keep the shim deliberately small.
    torch.cuda.is_available = torch.npu.is_available
    torch.cuda.device_count = torch.npu.device_count
    torch.cuda.current_device = torch.npu.current_device
    torch.cuda.set_device = torch.npu.set_device
    torch.cuda.synchronize = torch.npu.synchronize

    # empty_cache is only available in some torch_npu versions; when it is
    # missing, leave torch.cuda.empty_cache untouched instead of swallowing
    # arbitrary exceptions.
    npu_empty_cache = getattr(torch.npu, "empty_cache", None)
    if npu_empty_cache is not None:
        torch.cuda.empty_cache = npu_empty_cache
try:
    # Importing torch_npu registers the torch.npu backend; any failure here
    # (module missing, or presumably its native libraries failing to load)
    # is treated as "no NPU on this host".
    import torch_npu  # noqa: F401
except Exception:
    # Catch Exception rather than a bare ``except:`` so that KeyboardInterrupt
    # and SystemExit still propagate.
    print("no npu. use native cuda")
else:
    # Shim only when an NPU is usable and real CUDA is not, so CUDA hosts that
    # also have torch_npu installed keep their native torch.cuda functions.
    # Errors from this check or from the shim are deliberately not caught:
    # misreporting them as "no npu" would hide a partially patched torch.cuda.
    if torch.npu.is_available() and not torch.cuda.is_available():
        enable_cuda_to_npu_shim()
# 1) Optional, for trusted checkpoints only: when the weights come from a
#    PyTorch Lightning .ckpt, allowlist its ModelCheckpoint class via
#    torch.serialization.add_safe_globals.
#    NOTE(review): add_safe_globals only affects weights_only=True loads; the
#    torch.load wrapper in this file defaults weights_only=False, so this
#    matters only for callers that explicitly pass weights_only=True.
try:
    # Order matters: if add_safe_globals is unavailable (older torch),
    # pytorch_lightning is never imported.
    from torch.serialization import add_safe_globals
    from pytorch_lightning.callbacks.model_checkpoint import (
        ModelCheckpoint,
    )

    add_safe_globals([ModelCheckpoint])
except Exception:
    # Best-effort: a missing or broken dependency simply leaves the
    # allowlist unchanged.
    pass
# 2) Make torch.load deserialize onto the CPU by default, so checkpoints saved
#    from CUDA tensors do not fail to deserialize on hosts without CUDA.
#    SECURITY NOTE(review): weights_only also defaults to False here, which lets
#    torch.load unpickle arbitrary objects (arbitrary code execution on an
#    untrusted file). Only load checkpoints from trusted sources; callers may
#    still pass weights_only=True explicitly.
_orig_load = torch.load


@wraps(_orig_load)  # keep torch.load's name, docs and signature (__wrapped__)
def _load_map_to_cpu(*args, **kwargs):
    """Call the original ``torch.load`` with CPU-friendly defaults.

    Defaults are applied only when the caller did not supply the argument:

    * ``map_location="cpu"``. This is skipped when ``map_location`` was passed
      positionally (it is ``torch.load``'s second positional parameter);
      supplying it both ways would raise ``TypeError: got multiple values``.
    * ``weights_only=False``.

    All arguments are forwarded unchanged; returns whatever ``torch.load``
    returns.
    """
    if len(args) < 2:
        kwargs.setdefault("map_location", "cpu")
    kwargs.setdefault("weights_only", False)
    return _orig_load(*args, **kwargs)


torch.load = _load_map_to_cpu