176 lines
5.0 KiB
Python
176 lines
5.0 KiB
Python
import abc
|
|
import hashlib
|
|
import os
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
from ..common.build import _build
|
|
from .cache import get_cache_manager
|
|
|
|
|
|
class DriverBase(metaclass=abc.ABCMeta):
    """Common base class for the backend drivers defined in this module.

    The integer class constants identify a backend; concrete drivers store one
    of them (or ``None``) in their ``backend`` attribute.
    """

    # Backend identifiers.
    CUDA = 0
    HIP = 1

    @staticmethod
    def third_party_dir():
        """Return the ``third_party`` directory that sits one level above this module.

        The result is absolute but not normalized: it keeps the literal ``..``
        path component.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        third_party = os.path.join(module_dir, "..", "third_party")
        return third_party

    def __init__(self) -> None:
        # The base class holds no per-instance state.
        pass
|
|
# -----------------------------
|
|
# CUDA
|
|
# -----------------------------
|
|
|
|
|
|
class CudaUtils(object):
    """Shared singleton exposing functions from the compiled ``cuda_utils`` extension.

    On first construction, ``backends/cuda.c`` (next to this file) is built into
    ``cuda_utils.so``, unless a cached build keyed by the MD5 of the source
    already exists. The library is then loaded, and its ``load_binary`` and
    ``get_device_properties`` functions are bound as instance attributes.
    """

    def __new__(cls):
        # Every ``CudaUtils()`` call returns the same instance, created lazily.
        if not hasattr(cls, 'instance'):
            cls.instance = super(CudaUtils, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        # Python calls ``__init__`` on every ``CudaUtils()`` even though ``__new__``
        # hands back the shared instance. Skip the file read, hashing, cache lookup
        # and module load once they have completed successfully.
        if getattr(self, "_initialized", False):
            return
        dirname = os.path.dirname(os.path.realpath(__file__))
        src = Path(os.path.join(dirname, "backends", "cuda.c")).read_text(encoding="utf-8")
        # The cache key is derived from the C source, so changing cuda.c forces a rebuild.
        key = hashlib.md5(src.encode("utf-8")).hexdigest()
        cache = get_cache_manager(key)
        fname = "cuda_utils.so"
        cache_path = cache.get_file(fname)
        if cache_path is None:
            # Cache miss: compile in a scratch directory, then copy the results into the cache.
            with tempfile.TemporaryDirectory() as tmpdir:
                src_path = os.path.join(tmpdir, "main.c")
                with open(src_path, "w", encoding="utf-8") as f:
                    f.write(src)
                so = _build("cuda_utils", src_path, tmpdir)
                # Also store the C source in the same cache entry.
                cache.put(src, "main.c", binary=False)
                # Read the built library before tmpdir (and the .so inside it) is deleted.
                with open(so, "rb") as f:
                    cache_path = cache.put(f.read(), fname, binary=True)
        import importlib.util
        spec = importlib.util.spec_from_file_location("cuda_utils", cache_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        self.load_binary = mod.load_binary
        self.get_device_properties = mod.get_device_properties
        # Set last, so a failed attempt is retried on the next construction.
        self._initialized = True
|
|
|
|
|
|
class CudaDriver(DriverBase):
    """Singleton driver for the CUDA backend.

    Exposes the shared ``CudaUtils`` helper as ``utils`` and sets ``backend``
    to the class's ``CUDA`` identifier.
    """

    def __new__(cls):
        # One shared instance, created on the first ``CudaDriver()`` call.
        try:
            return cls.instance
        except AttributeError:
            cls.instance = super().__new__(cls)
            return cls.instance

    def __init__(self):
        # Runs on every CudaDriver() call; both assignments are the same each time.
        self.utils = CudaUtils()
        self.backend = self.CUDA
|
|
|
|
# -----------------------------
|
|
# HIP
|
|
# -----------------------------
|
|
|
|
|
|
class HIPUtils(object):
    """Shared singleton exposing functions from the compiled ``hip_utils`` extension.

    On first construction, ``backends/hip.c`` (next to this file) is built into
    ``hip_utils.so``, unless a cached build keyed by the MD5 of the source
    already exists. The library is then loaded, and its ``load_binary`` and
    ``get_device_properties`` functions are bound as instance attributes.
    """

    def __new__(cls):
        # Every ``HIPUtils()`` call returns the same instance, created lazily.
        if not hasattr(cls, 'instance'):
            cls.instance = super(HIPUtils, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        # Python calls ``__init__`` on every ``HIPUtils()`` even though ``__new__``
        # hands back the shared instance. Skip the file read, hashing, cache lookup
        # and module load once they have completed successfully.
        if getattr(self, "_initialized", False):
            return
        dirname = os.path.dirname(os.path.realpath(__file__))
        src = Path(os.path.join(dirname, "backends", "hip.c")).read_text(encoding="utf-8")
        # The cache key is derived from the C source, so changing hip.c forces a rebuild.
        key = hashlib.md5(src.encode("utf-8")).hexdigest()
        cache = get_cache_manager(key)
        fname = "hip_utils.so"
        cache_path = cache.get_file(fname)
        if cache_path is None:
            # Cache miss: compile in a scratch directory, then copy the library into the cache.
            with tempfile.TemporaryDirectory() as tmpdir:
                src_path = os.path.join(tmpdir, "main.c")
                with open(src_path, "w", encoding="utf-8") as f:
                    f.write(src)
                so = _build("hip_utils", src_path, tmpdir)
                # Read the built library before tmpdir (and the .so inside it) is deleted.
                with open(so, "rb") as f:
                    cache_path = cache.put(f.read(), fname, binary=True)
        import importlib.util
        spec = importlib.util.spec_from_file_location("hip_utils", cache_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        self.load_binary = mod.load_binary
        self.get_device_properties = mod.get_device_properties
        # Set last, so a failed attempt is retried on the next construction.
        self._initialized = True
|
|
|
|
|
|
class HIPDriver(DriverBase):
    """Singleton driver for the HIP backend.

    Exposes the shared ``HIPUtils`` helper as ``utils`` and sets ``backend``
    to the class's ``HIP`` identifier.
    """

    def __new__(cls):
        # One shared instance, created on the first ``HIPDriver()`` call.
        try:
            return cls.instance
        except AttributeError:
            cls.instance = super().__new__(cls)
            return cls.instance

    def __init__(self):
        # Runs on every HIPDriver() call; both assignments are the same each time.
        self.utils = HIPUtils()
        self.backend = self.HIP
|
|
|
|
|
|
class UnsupportedDriver(DriverBase):
    """Singleton fallback driver for when no supported backend is selected.

    Both ``utils`` and ``backend`` are ``None``.
    """

    def __new__(cls):
        # One shared instance, created on the first ``UnsupportedDriver()`` call.
        try:
            return cls.instance
        except AttributeError:
            cls.instance = super().__new__(cls)
            return cls.instance

    def __init__(self):
        # Chained assignment binds left to right: utils first, then backend.
        self.utils = self.backend = None
|
|
|
|
# -----------------------------
|
|
# Driver
|
|
# -----------------------------
|
|
|
|
|
|
class LazyProxy:
    """Proxy that builds its target on first use and forwards attribute access to it.

    ``init_fn`` is called with no arguments the first time an attribute is read,
    set or deleted through the proxy, or when ``str()`` is taken. Its result is
    stored and reused. ``repr()`` does not trigger initialization.

    Args:
        init_fn: Zero-argument callable that returns the object to wrap.
    """

    # Attributes stored on the proxy itself rather than forwarded to the target.
    _OWN_ATTRS = ("_init_fn", "_obj")

    def __init__(self, init_fn):
        self._init_fn = init_fn
        # ``None`` means the target has not been built yet.
        self._obj = None

    def _initialize_obj(self):
        """Build the target by calling ``init_fn`` if that has not happened yet."""
        if self._obj is None:
            self._obj = self._init_fn()

    def __getattr__(self, name):
        # Only reached when normal lookup fails. For the proxy's own attributes that
        # means __init__ never ran, e.g. an instance created via __new__ by copy or
        # pickle. Forwarding would call _initialize_obj, which reads self._obj and
        # lands back here, recursing without end; report a plain AttributeError.
        if name in self._OWN_ATTRS:
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        self._initialize_obj()
        return getattr(self._obj, name)

    def __setattr__(self, name, value):
        if name in self._OWN_ATTRS:
            super().__setattr__(name, value)
        else:
            self._initialize_obj()
            setattr(self._obj, name, value)

    def __delattr__(self, name):
        # Mirror __setattr__: the proxy's own attributes are deleted from the proxy.
        if name in self._OWN_ATTRS:
            super().__delattr__(name)
        else:
            self._initialize_obj()
            delattr(self._obj, name)

    def __repr__(self):
        if self._obj is None:
            return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>"
        return repr(self._obj)

    def __str__(self):
        self._initialize_obj()
        return str(self._obj)
|
|
|
|
|
|
def initialize_driver():
    """Instantiate the driver that matches the installed PyTorch build.

    Selection order:
    1. ``HIPDriver`` if ``torch.version.hip`` is not None. This is checked
       first, so it wins even when ``torch.cuda.is_available()`` is false.
    2. ``CudaDriver`` if ``torch.cuda.is_available()`` is true.
    3. ``UnsupportedDriver`` otherwise.
    """
    import torch

    if torch.version.hip is not None:
        return HIPDriver()
    if torch.cuda.is_available():
        return CudaDriver()
    return UnsupportedDriver()
|
|
|
|
|
|
# Module-level driver handle. ``initialize_driver`` (and with it the ``torch``
# import and any backend setup) does not run at import time; it runs on the
# first attribute access, assignment, deletion or ``str()`` through this proxy.
driver = LazyProxy(initialize_driver)
|