138 lines
4.5 KiB
Python
138 lines
4.5 KiB
Python
import contextlib
|
|
import functools
|
|
import io
|
|
import os
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
import sysconfig
|
|
|
|
import setuptools
|
|
|
|
|
|
# TODO: is_hip shouldn't be here
|
|
def is_hip():
    """Return True when the installed torch reports a HIP version.

    torch is imported lazily so that merely importing this module does not
    require (or pay for) loading torch.
    """
    import torch

    hip_version = torch.version.hip
    return hip_version is not None
|
|
|
|
|
|
def is_corex():
    """Return whether torch exposes a ``corex`` attribute that equals True.

    A torch build without the attribute yields False. torch is imported
    lazily, inside the function, rather than at module import time.
    """
    import torch

    if not hasattr(torch, "corex"):
        return False
    # Equality (not identity) is kept on purpose so any value that compares
    # equal to True is still accepted.
    return torch.corex == True  # noqa: E712
|
|
|
|
|
|
@functools.lru_cache()
def cuda_home_dirs():
    """Return the toolkit root directory used for CUDA-style headers and libs.

    ``CUDA_HOME`` wins whenever it is set, and in that case no external
    program is run. Otherwise the root is derived from the first ``clang++``
    path printed by ``whereis`` by dropping its last two path components
    (``<root>/bin/clang++`` -> ``<root>``).

    The result is memoized, so later changes to the environment are not seen
    unless ``cuda_home_dirs.cache_clear()`` is called.

    Returns:
        The toolkit root as a string.

    Raises:
        RuntimeError: ``CUDA_HOME`` is unset and ``whereis`` found no clang++.
    """
    cuda_home = os.getenv("CUDA_HOME")
    if cuda_home is not None:
        # Explicit configuration: do not require whereis/clang++ to exist.
        return cuda_home

    # whereis prints "clang++: <path> <path> ..."; the first token is the label.
    tokens = subprocess.check_output(["whereis", "clang++"]).decode().split()
    if len(tokens) < 2:
        raise RuntimeError(
            "Failed to locate clang++ via `whereis`. "
            "Please set the CUDA_HOME environment variable."
        )
    return os.path.dirname(os.path.dirname(tokens[1]))
|
|
|
|
|
|
@functools.lru_cache()
def libcuda_dirs():
    """Return the directories of every ``libcuda.so`` path reported by ``whereis``.

    The list is empty when ``whereis`` prints only its label. The result is
    memoized, so ``whereis`` is run at most once per process.
    """
    raw_output = subprocess.check_output(["whereis", "libcuda.so"])
    fields = raw_output.decode().split()
    # fields[0] is the "libcuda.so:" label; the remaining fields are paths.
    return list(map(os.path.dirname, fields[1:]))
|
|
|
|
|
|
@functools.lru_cache()
def rocm_path_dir():
    """Return the ROCm root: ``$ROCM_PATH`` if set, else ``/opt/rocm``.

    The value is memoized on first call.
    """
    try:
        return os.environ["ROCM_PATH"]
    except KeyError:
        return "/opt/rocm"
|
|
|
|
|
|
@contextlib.contextmanager
def quiet():
    """Silence ``sys.stdout`` and ``sys.stderr`` for the ``with`` body.

    Both streams are pointed at throwaway in-memory buffers and are put back
    when the body exits, whether normally or through an exception.
    """
    sink_out, sink_err = io.StringIO(), io.StringIO()
    with contextlib.redirect_stdout(sink_out), contextlib.redirect_stderr(sink_err):
        yield
|
|
|
|
|
|
def _find_c_compiler():
    """Return the C compiler to use: ``$CC`` if set, else clang/gcc from PATH."""
    cc = os.environ.get("CC")
    if cc is not None:
        return cc
    # TODO: support more things here.
    clang = shutil.which("clang")
    gcc = shutil.which("gcc")
    # clang is tried first under CoreX, gcc first otherwise.
    candidates = (clang, gcc) if is_corex() else (gcc, clang)
    cc = next((path for path in candidates if path is not None), None)
    if cc is None:
        raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.")
    return cc


def _python_include_dir():
    """Return the Python header directory for the default install scheme."""
    # This function was renamed and made public in Python 3.10
    if hasattr(sysconfig, 'get_default_scheme'):
        scheme = sysconfig.get_default_scheme()
    else:
        scheme = sysconfig._get_default_scheme()
    # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
    # path changes to include 'local'. This change is required to use triton with system-wide python.
    if scheme == 'posix_local':
        scheme = 'posix_prefix'
    return sysconfig.get_paths(scheme=scheme)["include"]


def _build(name, src, srcdir):
    """Compile the C file ``src`` into a shared extension module in ``srcdir``.

    The compiler is invoked directly (no setuptools). Which headers and
    libraries are used depends on ``is_hip()`` / ``is_corex()``:

    * HIP: headers and libs under ``rocm_path_dir()``, linked with amdhip64.
    * CoreX: headers and ``lib64`` under ``cuda_home_dirs()``, linked with cuda.
    * otherwise: libs from ``libcuda_dirs()`` and headers from the
      ``third_party/cuda/include`` directory next to this package, falling back
      to this package's own ``include`` directory when only that one has
      ``cuda.h``.

    Args:
        name: Module name; the output file is ``<name><EXT_SUFFIX>``.
        src: Path of the C source file handed to the compiler.
        srcdir: Directory that receives the shared object; also added to the
            include path.

    Returns:
        The path of the built shared object.

    Raises:
        RuntimeError: no C compiler could be found and ``CC`` is not set.
        subprocess.CalledProcessError: the compiler exited with a non-zero
            status.
    """
    use_hip = is_hip()
    if use_hip:
        hip_lib_dir = os.path.join(rocm_path_dir(), "lib")
        hip_include_dir = os.path.join(rocm_path_dir(), "include")
    elif is_corex():
        cuda_path = cuda_home_dirs()
        cu_include_dir = os.path.join(cuda_path, "include")
        cuda_lib_dirs = [os.path.join(cuda_path, "lib64")]
    else:
        # NOTE(review): the bundled-header fallback below is applied only on
        # this (non-CoreX) path; confirm that matches the intended nesting.
        cuda_lib_dirs = libcuda_dirs()
        base_dir = os.path.join(os.path.dirname(__file__), os.path.pardir)
        cuda_path = os.path.join(base_dir, "third_party", "cuda")

        cu_include_dir = os.path.join(cuda_path, "include")
        triton_include_dir = os.path.join(os.path.dirname(__file__), "include")
        cuda_header = os.path.join(cu_include_dir, "cuda.h")
        triton_cuda_header = os.path.join(triton_include_dir, "cuda.h")
        if not os.path.exists(cuda_header) and os.path.exists(triton_cuda_header):
            cu_include_dir = triton_include_dir

    suffix = sysconfig.get_config_var('EXT_SUFFIX')
    so = os.path.join(srcdir, f"{name}{suffix}")

    cc = _find_c_compiler()
    py_include_dir = _python_include_dir()

    if use_hip:
        cc_cmd = [
            cc, src,
            f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}",
            "-shared", "-fPIC",
            f"-L{hip_lib_dir}", "-lamdhip64",
            "-o", so,
        ]
    else:
        cc_cmd = [
            cc, src, "-O3",
            f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}",
            "-shared", "-fPIC",
            "-lcuda",
            "-o", so,
        ]
        cc_cmd += [f"-L{lib_dir}" for lib_dir in cuda_lib_dirs]

    # check_call raises CalledProcessError on a non-zero exit status, so
    # reaching the return means the compile succeeded. The setuptools
    # fallback that used to follow this call could never execute for that
    # reason (and referenced CUDA-only variables), so it has been dropped.
    subprocess.check_call(cc_cmd)
    return so
|