# enginex-vastai-va16-vllm/torch_vacc/__init__.py

import atexit
import ctypes
import os
import sys
import types

import torch
import torch.distributed

from .version import __version__
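

# Eagerly dlopen the bundled Vastai runtime libraries (and, for non-system
# Python installs, libpython itself) so the compiled extension imported below
# can resolve its shared-library dependencies from _vacc_libs.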
def register_runtime_libraries() -> None:
    try:
        libpython_so = f"libpython{sys.version_info.major}.{sys.version_info.minor}.so"
        base_prefix = getattr(sys, "base_prefix", sys.prefix)
        if not base_prefix.startswith("/usr"):  # like conda or virtualenv
            ctypes.CDLL(os.path.join(base_prefix, "lib", libpython_so))
        this_path = os.path.dirname(os.path.realpath(__file__))
        rt_dll_dpath = os.path.join(this_path, "_vacc_libs")
        ctypes.CDLL(os.path.join(rt_dll_dpath, "libodsp.so"))
        ctypes.CDLL(os.path.join(rt_dll_dpath, "libvaccrt.so"))
        ctypes.CDLL(os.path.join(rt_dll_dpath, "libvnnl.so"))
        ctypes.CDLL(os.path.join(rt_dll_dpath, "libvccl.so"))
        ctypes.CDLL(os.path.join(rt_dll_dpath, "libvacc_core.so"))
    except Exception as e:
        raise RuntimeError("Vastai runtime library not loaded.") from e


register_runtime_libraries()
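
# With the vendor runtime preloaded, import the compiled extension and
# initialize the backend at module import time.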
from ._vacc_libs import _torch_vacc as _C

try:
    _C._init_torch_vacc_module()
except Exception as e:
    raise RuntimeError("Failed to init torch_vacc.") from e
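

# _apply_patches installs a table of (dotted path, replacement) monkey patches
# onto torch. The _getattr helper walks the dotted path starting from torch,
# creating empty placeholder modules for intermediate levels that do not exist.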
def _apply_patches(monkey_patches):
    def _getattr(module_list, root_module=torch):
        if len(module_list) <= 1:
            return root_module
        if hasattr(root_module, module_list[0]):
            return _getattr(module_list[1:], getattr(root_module, module_list[0]))
        else:
            empty_module_name = f"{root_module.__name__}.{module_list[0]}"
            sys.modules[empty_module_name] = types.ModuleType(empty_module_name)
            setattr(root_module, module_list[0], sys.modules.get(empty_module_name))
            return _getattr(module_list[1:], getattr(root_module, module_list[0]))
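
    # Non-module patches replace the target attribute directly. Module patches
    # either replace the destination wholesale (when the destination is missing
    # or the patch has no __all__) or copy only the names listed in __all__
    # onto the existing destination module.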
    for patch_pair in monkey_patches:
        dest, patch = patch_pair
        dest_module = _getattr(dest.split("."), root_module=torch)
        last_module_level = dest.split(".")[-1]
        if not isinstance(patch, types.ModuleType):
            setattr(dest_module, last_module_level, patch)
            continue
        if not hasattr(dest_module, last_module_level) or not hasattr(patch, "__all__"):
            setattr(dest_module, last_module_level, patch)
            sys.modules[f"{dest_module.__name__}.{last_module_level}"] = patch
            continue
        assert hasattr(patch, "__all__"), "Patch module must have __all__ definition."
        dest_module = getattr(dest_module, last_module_level)
        for attr in patch.__all__:
            setattr(dest_module, attr, getattr(patch, attr))
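

# Minimal usage sketch for _apply_patches (hypothetical destination and patch
# names; this file does not define a patch table at this point):
#
#   _apply_patches([("nn.functional.my_op", my_patched_op)])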


import torch_vacc.vacc as vacc

# register "vacc" module/functions to torch
torch._register_device_module("vacc", vacc)
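
# The "vacc" device is backed by PyTorch's PrivateUse1 dispatch key; generate
# the matching Tensor/Module/Storage convenience methods (e.g. Tensor.is_vacc),
# excluding the quantized dtypes listed below, which the backend does not
# support.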
unsupported_dtype = [
    torch.quint8,
    torch.quint4x2,
    torch.quint2x4,
    torch.qint32,
    torch.qint8,
]
torch.utils.generate_methods_for_privateuse1_backend(
    for_tensor=True,
    for_module=True,
    for_storage=True,  # TODO(qingsong): do we support storage?
    unsupported_dtype=unsupported_dtype,
)

# register legacy *DtypeTensor types into torch.vacc
_C._initialize_python_bindings()

# init seed generators and the default vacc generator
vacc.init()
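

# VCCL appears to be the backend's collective-communication library (the NCCL
# analogue, loaded above as libvccl.so); expose an availability probe in the
# style of torch.distributed.is_nccl_available().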
def is_vccl_available() -> bool:
    return True


torch.distributed.is_vccl_available = is_vccl_available
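

# Thin wrappers over the extension's logging and operator-introspection hooks.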
def set_global_log_level(log_level):
    _C.set_global_log_level(log_level.upper())


def print_vacc_ops():
    _C._print_vacc_ops()


def vacc_ops_list():
    return _C._vacc_ops_list().split(",")


def print_vacc_selective_ops():
    _C._print_vacc_selective_ops()
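

# Tear down the native backend at interpreter exit so device resources are
# released cleanly.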
def _vacc_shutdown():
    _C._vacc_module_shutdown()


atexit.register(_vacc_shutdown)
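

# Usage sketch (assumes a Vastai device is available; the "vacc" device string
# follows the usual torch convention once the backend is registered):
#
#   import torch
#   import torch_vacc
#   x = torch.randn(4, 4).to("vacc")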