125 lines
3.6 KiB
Python
125 lines
3.6 KiB
Python
import atexit
|
|
import ctypes
|
|
import os
|
|
import sys
|
|
import types
|
|
import torch
|
|
import torch.distributed
|
|
|
|
from .version import __version__
|
|
|
|
|
|
def register_runtime_libraries() -> None:
|
|
try:
|
|
libpython_so = f"libpython{sys.version_info.major}.{sys.version_info.minor}.so"
|
|
base_prefix = getattr(sys, "base_prefix", sys.prefix)
|
|
if not base_prefix.startswith("/usr"): # like conda or virtualenv
|
|
ctypes.CDLL(os.path.join(base_prefix, "lib", libpython_so))
|
|
|
|
this_path = os.path.dirname(os.path.realpath(__file__))
|
|
rt_dll_dpath = os.path.join(this_path, "_vacc_libs")
|
|
ctypes.CDLL(os.path.join(rt_dll_dpath, "libodsp.so"))
|
|
ctypes.CDLL(os.path.join(rt_dll_dpath, "libvaccrt.so"))
|
|
ctypes.CDLL(os.path.join(rt_dll_dpath, "libvnnl.so"))
|
|
ctypes.CDLL(os.path.join(rt_dll_dpath, "libvccl.so"))
|
|
ctypes.CDLL(os.path.join(rt_dll_dpath, "libvacc_core.so"))
|
|
except Exception as e:
|
|
raise RuntimeError("Vastai runtime library not loaded.") from e
|
|
|
|
|
|
# The native extension imported below presumably depends on the libraries
# preloaded here, so this call must stay ahead of the import.
register_runtime_libraries()

from ._vacc_libs import _torch_vacc as _C

try:
    _C._init_torch_vacc_module()
except Exception as init_error:
    raise RuntimeError("Failed to init torch_vacc.") from init_error
|
|
|
|
|
|
def _apply_patches(monkey_patches):
|
|
def _getattr(module_list, root_module=torch):
|
|
if len(module_list) <= 1:
|
|
return root_module
|
|
|
|
if hasattr(root_module, module_list[0]):
|
|
return _getattr(module_list[1:], getattr(root_module, module_list[0]))
|
|
else:
|
|
empty_module_name = f"{root_module.__name__}.{module_list[0]}"
|
|
sys.modules[empty_module_name] = types.ModuleType(empty_module_name)
|
|
setattr(root_module, module_list[0], sys.modules.get(empty_module_name))
|
|
return _getattr(module_list[1:], getattr(root_module, module_list[0]))
|
|
|
|
for patch_pair in monkey_patches:
|
|
dest, patch = patch_pair
|
|
dest_module = _getattr(dest.split("."), root_module=torch)
|
|
last_module_level = dest.split(".")[-1]
|
|
if not isinstance(patch, types.ModuleType):
|
|
setattr(dest_module, last_module_level, patch)
|
|
continue
|
|
|
|
if not hasattr(dest_module, last_module_level) or not hasattr(patch, "__all__"):
|
|
setattr(dest_module, last_module_level, patch)
|
|
sys.modules[f"{dest_module.__name__}.{last_module_level}"] = patch
|
|
continue
|
|
|
|
assert hasattr(patch, "__all__"), "Patch module must have __all__ definition."
|
|
dest_module = getattr(dest_module, last_module_level)
|
|
for attr in patch.__all__:
|
|
setattr(dest_module, attr, getattr(patch, attr))
|
|
|
|
|
|
import torch_vacc.vacc as vacc

# Quantized dtypes handed to generate_methods_for_privateuse1_backend as
# ``unsupported_dtype`` (per the PyTorch docs, this affects the generated
# storage methods).
unsupported_dtype = [
    torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8,
]

# register "vacc" module/functions to torch
torch._register_device_module("vacc", vacc)

torch.utils.generate_methods_for_privateuse1_backend(
    unsupported_dtype=unsupported_dtype,
    for_tensor=True,
    for_module=True,
    for_storage=True,  # TODO(qingsong): do we support storage?
)

# Register the legacy *DtypeTensor types into torch.vacc.
_C._initialize_python_bindings()

# Initialise the seed generators, including the vacc default generator.
vacc.init()
|
|
|
|
|
|
def is_vccl_available() -> bool:
    """Report whether the VCCL library is available.

    Always ``True``. Importing this package already required
    ``register_runtime_libraries`` to load ``libvccl.so`` successfully;
    otherwise the import would have raised.
    """
    vccl_loaded = True
    return vccl_loaded


# Expose the check alongside torch.distributed's own ``is_*_available`` helpers.
setattr(torch.distributed, "is_vccl_available", is_vccl_available)
|
|
|
|
|
|
def set_global_log_level(log_level):
    """Forward a log level to the native ``_C.set_global_log_level``.

    Args:
        log_level: level name as a string. It is upper-cased before being
            passed on; the accepted names are defined by the native side.
    """
    normalized_level = log_level.upper()
    _C.set_global_log_level(normalized_level)
|
|
|
|
|
|
def print_vacc_ops():
    """Delegate to the native ``_C._print_vacc_ops``.

    Presumably prints the operators implemented for vacc; the output format is
    defined by the native extension. Returns ``None``.
    """
    printer = _C._print_vacc_ops
    printer()
|
|
|
|
|
|
def vacc_ops_list():
|
|
return _C._vacc_ops_list().split(",")
|
|
|
|
|
|
def print_vacc_selective_ops():
    """Delegate to the native ``_C._print_vacc_selective_ops``.

    What counts as a "selective" op is defined by the native extension;
    this wrapper only forwards the call and returns ``None``.
    """
    printer = _C._print_vacc_selective_ops
    printer()
|
|
|
|
|
|
@atexit.register
def _vacc_shutdown():
    """Shut down the native torch_vacc module when the interpreter exits.

    Registered with ``atexit`` (used as a decorator, which returns the
    function unchanged) and forwards to ``_C._vacc_module_shutdown``.
    """
    shutdown = _C._vacc_module_shutdown
    shutdown()
|