Files
2025-08-05 19:02:46 +08:00

138 lines
4.5 KiB
Python

import contextlib
import functools
import io
import os
import shutil
import subprocess
import sys
import sysconfig
import setuptools
# TODO: is_hip shouldn't be here
def is_hip():
    """Return True when ``torch.version.hip`` is set (i.e. is not ``None``)."""
    # Imported lazily so that importing this module does not import torch.
    import torch

    hip_version = torch.version.hip
    return hip_version is not None
def is_corex():
    """Return whether ``torch`` has a ``corex`` attribute that compares equal to ``True``."""
    # Imported lazily so that importing this module does not import torch.
    import torch

    if not hasattr(torch, "corex"):
        return False
    return torch.corex == True  # noqa: E712
@functools.lru_cache()
def cuda_home_dirs():
    """Return the CUDA home directory.

    ``CUDA_HOME`` is returned as-is when it is set. Only when it is unset is
    the directory derived from the first ``clang++`` path reported by
    ``whereis``: two levels above the binary (``<home>/bin/clang++`` ->
    ``<home>``). The result is cached for the lifetime of the process.

    Raises:
        RuntimeError: If ``CUDA_HOME`` is unset and ``whereis`` reports no
            location for ``clang++``.
    """
    cuda_home = os.getenv("CUDA_HOME")
    if cuda_home is not None:
        # An explicit setting must not depend on `whereis`/`clang++` existing.
        return cuda_home

    fields = subprocess.check_output(["whereis", "clang++"]).decode().strip().split()
    # fields[0] is the "clang++:" label; any paths found follow it.
    if len(fields) < 2:
        raise RuntimeError(
            "Failed to locate clang++ via `whereis`. "
            "Please specify the toolkit root via the CUDA_HOME environment variable."
        )
    return os.path.dirname(os.path.dirname(fields[1]))
@functools.lru_cache()
def libcuda_dirs():
    """Return the directories of every ``libcuda.so`` path reported by ``whereis``.

    The first whitespace-separated token of the ``whereis`` output is dropped
    and the parent directory of each remaining token is returned, in order.
    The result is cached for the lifetime of the process.
    """
    raw_output = subprocess.check_output(["whereis", "libcuda.so"])
    tokens = raw_output.decode().strip().split()
    return list(map(os.path.dirname, tokens[1:]))
@functools.lru_cache()
def rocm_path_dir():
    """Return ``ROCM_PATH`` from the environment, or ``/opt/rocm`` when unset.

    The value is cached after the first call.
    """
    fallback = "/opt/rocm"
    return os.environ.get("ROCM_PATH", fallback)
@contextlib.contextmanager
def quiet():
    """Context manager that swallows everything written to stdout and stderr.

    Inside the ``with`` body ``sys.stdout`` and ``sys.stderr`` each point at a
    fresh in-memory ``io.StringIO``; the previous streams are put back on
    exit, including when the body raises.
    """
    stdout_sink = io.StringIO()
    stderr_sink = io.StringIO()
    with contextlib.redirect_stdout(stdout_sink), contextlib.redirect_stderr(stderr_sink):
        yield
def _find_c_compiler():
    """Return the C compiler to invoke: ``$CC`` if set, else clang/gcc found on PATH."""
    cc = os.environ.get("CC")
    if cc is not None:
        return cc
    # TODO: support more things here.
    clang = shutil.which("clang")
    gcc = shutil.which("gcc")
    # CoreX prefers clang over gcc; every other setup prefers gcc over clang.
    candidates = (clang, gcc) if is_corex() else (gcc, clang)
    cc = next((path for path in candidates if path is not None), None)
    if cc is None:
        raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.")
    return cc


def _python_include_dir():
    """Return the Python header directory for the interpreter's default install scheme."""
    # This function was renamed and made public in Python 3.10
    if hasattr(sysconfig, "get_default_scheme"):
        scheme = sysconfig.get_default_scheme()
    else:
        scheme = sysconfig._get_default_scheme()
    # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
    # path changes to include 'local'. This change is required to use triton with system-wide python.
    if scheme == "posix_local":
        scheme = "posix_prefix"
    return sysconfig.get_paths(scheme=scheme)["include"]


def _build(name, src, srcdir):
    """Compile the C source ``src`` into a shared-object extension inside ``srcdir``.

    The C compiler is invoked directly. Include and library directories are
    chosen according to ``is_hip()`` and ``is_corex()``: HIP uses the ROCm
    tree, CoreX uses the CUDA home directory, and the default path uses the
    ``libcuda.so`` locations plus the ``third_party/cuda`` headers.

    Args:
        name: Module name; the output file is ``<name><EXT_SUFFIX>``.
        src: Path of the C source file to compile.
        srcdir: Directory that is added to the include path and receives the
            built shared object.

    Returns:
        Path of the built shared object.

    Raises:
        RuntimeError: If ``CC`` is unset and neither clang nor gcc is on PATH.
        subprocess.CalledProcessError: If the compiler exits with a non-zero
            status.
    """
    hip = is_hip()
    if hip:
        rocm_root = rocm_path_dir()
        hip_lib_dir = os.path.join(rocm_root, "lib")
        hip_include_dir = os.path.join(rocm_root, "include")
    elif is_corex():
        cuda_path = cuda_home_dirs()
        cu_include_dir = os.path.join(cuda_path, "include")
        cuda_lib_dirs = [os.path.join(cuda_path, "lib64")]
    else:
        cuda_lib_dirs = libcuda_dirs()
        package_dir = os.path.dirname(__file__)
        base_dir = os.path.join(package_dir, os.path.pardir)
        cuda_path = os.path.join(base_dir, "third_party", "cuda")
        cu_include_dir = os.path.join(cuda_path, "include")
        # Use the cuda.h shipped in this package's include/ directory when the
        # third_party copy is missing.
        # NOTE(review): this fallback is scoped to the non-CoreX path only;
        # confirm CoreX is not meant to use it as well.
        triton_include_dir = os.path.join(package_dir, "include")
        cuda_header = os.path.join(cu_include_dir, "cuda.h")
        triton_cuda_header = os.path.join(triton_include_dir, "cuda.h")
        if not os.path.exists(cuda_header) and os.path.exists(triton_cuda_header):
            cu_include_dir = triton_include_dir

    suffix = sysconfig.get_config_var("EXT_SUFFIX")
    so = os.path.join(srcdir, f"{name}{suffix}")

    cc = _find_c_compiler()
    py_include_dir = _python_include_dir()

    if hip:
        cc_cmd = [
            cc, src,
            f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}",
            "-shared", "-fPIC",
            f"-L{hip_lib_dir}", "-lamdhip64",
            "-o", so,
        ]
    else:
        cc_cmd = [
            cc, src, "-O3",
            f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}",
            "-shared", "-fPIC",
            "-lcuda",
            "-o", so,
        ]
        cc_cmd += [f"-L{lib_dir}" for lib_dir in cuda_lib_dirs]

    # check_call raises CalledProcessError on a non-zero exit status, so
    # returning normally from it already means the build succeeded.
    subprocess.check_call(cc_cmd)
    return so