# Import-time setup for the "vacc" PyTorch device backend.
#
# Statement order matters in this module: the native runtime libraries are
# loaded first, then the compiled extension is imported and initialised, and
# only after that is the Python-side ``vacc`` module registered with torch.
import atexit
import ctypes
import os
import sys
import types

import torch
import torch.distributed

from .version import __version__


def register_runtime_libraries() -> None:
    """Load the native shared libraries bundled in ``_vacc_libs``.

    When ``sys.base_prefix`` (falling back to ``sys.prefix``) does not start
    with ``/usr``, the interpreter's own ``libpython`` is loaded from
    ``<base_prefix>/lib`` first. The bundled libraries are then loaded with
    ``ctypes.CDLL`` in a fixed order.

    Raises:
        RuntimeError: If any step fails; the original exception is chained
            as ``__cause__``.
    """
    # Load order is kept exactly as the original sequence of CDLL calls.
    runtime_library_names = (
        "libodsp.so",
        "libvaccrt.so",
        "libvnnl.so",
        "libvccl.so",
        "libvacc_core.so",
    )
    try:
        libpython_so = (
            f"libpython{sys.version_info.major}.{sys.version_info.minor}.so"
        )
        base_prefix = getattr(sys, "base_prefix", sys.prefix)
        if not base_prefix.startswith("/usr"):
            # like conda or virtualenv
            ctypes.CDLL(os.path.join(base_prefix, "lib", libpython_so))

        this_path = os.path.dirname(os.path.realpath(__file__))
        rt_dll_dpath = os.path.join(this_path, "_vacc_libs")
        for library_name in runtime_library_names:
            ctypes.CDLL(os.path.join(rt_dll_dpath, library_name))
    except Exception as e:
        raise RuntimeError("Vastai runtime library not loaded.") from e


register_runtime_libraries()

# Imported only after the runtime libraries above have been loaded.
from ._vacc_libs import _torch_vacc as _C

try:
    _C._init_torch_vacc_module()
except Exception as e:
    raise RuntimeError("Failed to init torch_vacc.") from e


def _apply_patches(monkey_patches):
    """Install ``(dest, patch)`` pairs onto the ``torch`` namespace.

    Args:
        monkey_patches: Iterable of two-item pairs. ``dest`` is a dotted
            attribute path resolved starting from ``torch``; ``patch`` is the
            object to install at that path.

    For each pair:

    * Missing intermediate attributes along ``dest`` are created as empty
      ``types.ModuleType`` objects and registered in ``sys.modules``.
    * A ``patch`` that is not a module is assigned to the final attribute.
    * A module ``patch`` replaces the destination, and is registered in
      ``sys.modules``, when the destination does not exist yet or the patch
      defines no ``__all__``.
    * Otherwise every name listed in ``patch.__all__`` is copied onto the
      existing destination object.
    """

    def _resolve_parent(path_parts, root_module=torch):
        # Walk every component except the last one, creating placeholder
        # modules for attributes that do not exist yet.
        current = root_module
        for name in path_parts[:-1]:
            if not hasattr(current, name):
                qualified_name = f"{current.__name__}.{name}"
                placeholder = types.ModuleType(qualified_name)
                sys.modules[qualified_name] = placeholder
                setattr(current, name, placeholder)
            current = getattr(current, name)
        return current

    for dest, patch in monkey_patches:
        path_parts = dest.split(".")
        parent = _resolve_parent(path_parts, root_module=torch)
        leaf = path_parts[-1]

        if not isinstance(patch, types.ModuleType):
            setattr(parent, leaf, patch)
            continue

        if not hasattr(parent, leaf) or not hasattr(patch, "__all__"):
            setattr(parent, leaf, patch)
            sys.modules[f"{parent.__name__}.{leaf}"] = patch
            continue

        # Here the destination exists and the patch defines __all__, so only
        # the exported names are merged into the existing destination.
        target = getattr(parent, leaf)
        for attr in patch.__all__:
            setattr(target, attr, getattr(patch, attr))


import torch_vacc.vacc as vacc

# register "vacc" module/functions to torch
torch._register_device_module("vacc", vacc)

unsupported_dtype = [
    torch.quint8,
    torch.quint4x2,
    torch.quint2x4,
    torch.qint32,
    torch.qint8,
]
torch.utils.generate_methods_for_privateuse1_backend(
    for_tensor=True,
    for_module=True,
    for_storage=True,  # TODO(qingsong): do we support storage?
    unsupported_dtype=unsupported_dtype,
)

# register legacy *DtypeTensor into torch.vacc
_C._initialize_python_bindings()

# init seed generators, vacc default generator
vacc.init()


def is_vccl_available() -> bool:
    """Return ``True`` unconditionally."""
    return True


# Expose the check on torch.distributed as an attribute.
torch.distributed.is_vccl_available = is_vccl_available


def set_global_log_level(log_level):
    """Forward ``log_level``, upper-cased, to the native extension.

    Args:
        log_level: Object with an ``upper()`` method (e.g. a ``str``).
    """
    _C.set_global_log_level(log_level.upper())


def print_vacc_ops():
    """Call the native extension's ``_print_vacc_ops``."""
    _C._print_vacc_ops()


def vacc_ops_list():
    """Return the result of ``_C._vacc_ops_list()`` split on commas."""
    return _C._vacc_ops_list().split(",")


def print_vacc_selective_ops():
    """Call the native extension's ``_print_vacc_selective_ops``."""
    _C._print_vacc_selective_ops()


def _vacc_shutdown():
    """Call the native extension's ``_vacc_module_shutdown``."""
    _C._vacc_module_shutdown()


# Run the native shutdown hook when the interpreter exits.
atexit.register(_vacc_shutdown)