use env variable to control the build conf on the CPU build node (#3080)
This commit is contained in:
@@ -11,6 +11,9 @@ docker run --rm \
|
|||||||
${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \
|
${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \
|
||||||
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \
|
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \
|
||||||
export CUDA_VERSION=${CUDA_VERSION} && \
|
export CUDA_VERSION=${CUDA_VERSION} && \
|
||||||
|
export SGL_KERNEL_ENABLE_BF16=1 && \
|
||||||
|
export SGL_KERNEL_ENABLE_FP8=1 && \
|
||||||
|
export SGL_KERNEL_ENABLE_SM90A=1 && \
|
||||||
mkdir -p /usr/lib/x86_64-linux-gnu/ && \
|
mkdir -p /usr/lib/x86_64-linux-gnu/ && \
|
||||||
ln -s /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/x86_64-linux-gnu/libcuda.so && \
|
ln -s /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/x86_64-linux-gnu/libcuda.so && \
|
||||||
cd /sgl-kernel && \
|
cd /sgl-kernel && \
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from setuptools import find_packages, setup
|
from setuptools import find_packages, setup
|
||||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||||
from version import __version__
|
|
||||||
|
|
||||||
root = Path(__file__).parent.resolve()
|
root = Path(__file__).parent.resolve()
|
||||||
|
|
||||||
|
|
||||||
def update_wheel_platform_tag():
|
def _update_wheel_platform_tag():
|
||||||
wheel_dir = Path("dist")
|
wheel_dir = Path("dist")
|
||||||
if wheel_dir.exists() and wheel_dir.is_dir():
|
if wheel_dir.exists() and wheel_dir.is_dir():
|
||||||
old_wheel = next(wheel_dir.glob("*.whl"))
|
old_wheel = next(wheel_dir.glob("*.whl"))
|
||||||
@@ -18,21 +18,25 @@ def update_wheel_platform_tag():
|
|||||||
old_wheel.rename(new_wheel)
|
old_wheel.rename(new_wheel)
|
||||||
|
|
||||||
|
|
||||||
def get_cuda_version():
|
def _get_cuda_version():
|
||||||
if torch.version.cuda:
|
if torch.version.cuda:
|
||||||
return tuple(map(int, torch.version.cuda.split(".")))
|
return tuple(map(int, torch.version.cuda.split(".")))
|
||||||
return (0, 0)
|
return (0, 0)
|
||||||
|
|
||||||
|
|
||||||
def get_device_sm():
|
def _get_device_sm():
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
major, minor = torch.cuda.get_device_capability()
|
major, minor = torch.cuda.get_device_capability()
|
||||||
return major * 10 + minor
|
return major * 10 + minor
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
cuda_version = get_cuda_version()
|
def _get_version():
|
||||||
sm_version = get_device_sm()
|
with open(root / "pyproject.toml") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.startswith("version"):
|
||||||
|
return line.split("=")[1].strip().strip('"')
|
||||||
|
|
||||||
|
|
||||||
cutlass = root / "3rdparty" / "cutlass"
|
cutlass = root / "3rdparty" / "cutlass"
|
||||||
flashinfer = root / "3rdparty" / "flashinfer"
|
flashinfer = root / "3rdparty" / "flashinfer"
|
||||||
@@ -58,10 +62,16 @@ nvcc_flags = [
|
|||||||
"-DFLASHINFER_ENABLE_F16",
|
"-DFLASHINFER_ENABLE_F16",
|
||||||
]
|
]
|
||||||
|
|
||||||
if cuda_version >= (12, 0) and sm_version >= 90:
|
enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
|
||||||
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
|
enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1"
|
||||||
|
enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1"
|
||||||
|
cuda_version = _get_cuda_version()
|
||||||
|
sm_version = _get_device_sm()
|
||||||
|
|
||||||
if sm_version >= 90:
|
if torch.cuda.is_available():
|
||||||
|
if cuda_version >= (12, 0) and sm_version >= 90:
|
||||||
|
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
|
||||||
|
if sm_version >= 90:
|
||||||
nvcc_flags.extend(
|
nvcc_flags.extend(
|
||||||
[
|
[
|
||||||
"-DFLASHINFER_ENABLE_FP8",
|
"-DFLASHINFER_ENABLE_FP8",
|
||||||
@@ -69,7 +79,21 @@ if sm_version >= 90:
|
|||||||
"-DFLASHINFER_ENABLE_FP8_E5M2",
|
"-DFLASHINFER_ENABLE_FP8_E5M2",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
if sm_version >= 80:
|
if sm_version >= 80:
|
||||||
|
nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
|
||||||
|
else:
|
||||||
|
# compilation environment without GPU
|
||||||
|
if enable_sm90a:
|
||||||
|
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
|
||||||
|
if enable_fp8:
|
||||||
|
nvcc_flags.extend(
|
||||||
|
[
|
||||||
|
"-DFLASHINFER_ENABLE_FP8",
|
||||||
|
"-DFLASHINFER_ENABLE_FP8_E4M3",
|
||||||
|
"-DFLASHINFER_ENABLE_FP8_E5M2",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if enable_bf16:
|
||||||
nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
|
nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
|
||||||
|
|
||||||
for flag in [
|
for flag in [
|
||||||
@@ -82,6 +106,7 @@ for flag in [
|
|||||||
torch.utils.cpp_extension.COMMON_NVCC_FLAGS.remove(flag)
|
torch.utils.cpp_extension.COMMON_NVCC_FLAGS.remove(flag)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
cxx_flags = ["-O3"]
|
cxx_flags = ["-O3"]
|
||||||
libraries = ["c10", "torch", "torch_python", "cuda"]
|
libraries = ["c10", "torch", "torch_python", "cuda"]
|
||||||
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
|
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
|
||||||
@@ -116,11 +141,11 @@ ext_modules = [
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="sgl-kernel",
|
name="sgl-kernel",
|
||||||
version=__version__,
|
version=_get_version(),
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
package_dir={"": "src"},
|
package_dir={"": "src"},
|
||||||
ext_modules=ext_modules,
|
ext_modules=ext_modules,
|
||||||
cmdclass={"build_ext": BuildExtension},
|
cmdclass={"build_ext": BuildExtension},
|
||||||
)
|
)
|
||||||
|
|
||||||
update_wheel_platform_tag()
|
_update_wheel_platform_tag()
|
||||||
|
|||||||
Reference in New Issue
Block a user