# enginex-ascend-910-text2video/patch.py
#
# Runtime patch for Ascend 910 deployments: aliases common torch.cuda
# entry points to torch.npu when an NPU is present, and forces torch.load
# to deserialize checkpoints onto the CPU by default.
import torch
from functools import wraps
def enable_cuda_to_npu_shim():
    """Alias the most common ``torch.cuda`` entry points to ``torch.npu``.

    Lets CUDA-oriented code paths run unmodified on Ascend NPU hardware.
    Only a handful of frequently used functions are mapped, deliberately;
    business code is still advised to address devices explicitly, e.g.
    ``model.to("npu:0")``.

    Raises:
        ImportError: if ``torch_npu`` is not installed.
    """
    print("enable_cuda_to_npu_shim")
    import torch_npu  # noqa: F401 -- importing registers the "npu" backend
    # Map only the common functions; do not try to cover everything.
    torch.cuda.is_available = torch.npu.is_available
    torch.cuda.device_count = torch.npu.device_count
    torch.cuda.current_device = torch.npu.current_device
    torch.cuda.set_device = torch.npu.set_device
    torch.cuda.synchronize = torch.npu.synchronize
    try:
        # empty_cache is only present on some torch_npu versions.
        torch.cuda.empty_cache = torch.npu.empty_cache
    except AttributeError:
        # Best-effort: skip silently when this torch_npu build lacks it.
        pass
    # Device strings should uniformly use "npu".
# Activate the shim only when an NPU is present and no real CUDA device is.
try:
    import torch_npu  # noqa: F401 -- registers torch.npu
    if torch.npu.is_available() and not torch.cuda.is_available():
        enable_cuda_to_npu_shim()
except Exception:
    # Was a bare `except:` -- narrowed so SystemExit/KeyboardInterrupt
    # still propagate. Any failure here just means: stay on native CUDA.
    print("no npu. use native cuda")
# 1) Optional: if your weights come from a Lightning ckpt, allowlist its
#    class for torch's safe loading (only do this for trusted sources!).
try:
    from torch.serialization import add_safe_globals
    from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
    add_safe_globals([ModelCheckpoint])
except ImportError:
    # Older torch without add_safe_globals, or pytorch_lightning missing;
    # best-effort registration, safe to skip.
    pass
# 2) 统一把 torch.load 默认映射到 CPU避免 CUDA 反序列化错误
_orig_load = torch.load
def _load_map_to_cpu(*args, **kwargs):
kwargs.setdefault("map_location", "cpu")
kwargs.setdefault("weights_only", False)
return _orig_load(*args, **kwargs)
torch.load = _load_map_to_cpu