diff --git a/sgl-kernel/build.sh b/sgl-kernel/build.sh
index 0d8169579..c89922481 100755
--- a/sgl-kernel/build.sh
+++ b/sgl-kernel/build.sh
@@ -11,6 +11,9 @@ docker run --rm \
     ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \
     export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \
     export CUDA_VERSION=${CUDA_VERSION} && \
+    export SGL_KERNEL_ENABLE_BF16=1 && \
+    export SGL_KERNEL_ENABLE_FP8=1 && \
+    export SGL_KERNEL_ENABLE_SM90A=1 && \
     mkdir -p /usr/lib/x86_64-linux-gnu/ && \
     ln -s /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/x86_64-linux-gnu/libcuda.so && \
     cd /sgl-kernel && \
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index 71952655c..184cc08c4 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -1,14 +1,14 @@
+import os
 from pathlib import Path
 
 import torch
 from setuptools import find_packages, setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
-from version import __version__
 
 root = Path(__file__).parent.resolve()
 
 
-def update_wheel_platform_tag():
+def _update_wheel_platform_tag():
     wheel_dir = Path("dist")
     if wheel_dir.exists() and wheel_dir.is_dir():
         old_wheel = next(wheel_dir.glob("*.whl"))
@@ -18,21 +18,25 @@ def update_wheel_platform_tag():
         old_wheel.rename(new_wheel)
 
 
-def get_cuda_version():
+def _get_cuda_version():
     if torch.version.cuda:
         return tuple(map(int, torch.version.cuda.split(".")))
     return (0, 0)
 
 
-def get_device_sm():
+def _get_device_sm():
     if torch.cuda.is_available():
         major, minor = torch.cuda.get_device_capability()
         return major * 10 + minor
     return 0
 
 
-cuda_version = get_cuda_version()
-sm_version = get_device_sm()
+def _get_version():
+    with open(root / "pyproject.toml") as f:
+        for line in f:
+            if line.startswith("version"):
+                return line.split("=")[1].strip().strip('"')
+
 
 cutlass = root / "3rdparty" / "cutlass"
 flashinfer = root / "3rdparty" / "flashinfer"
@@ -58,19 +62,39 @@ nvcc_flags = [
     "-DFLASHINFER_ENABLE_F16",
 ]
 
-if cuda_version >= (12, 0) and sm_version >= 90:
-    nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
+enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
+enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1"
+enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1"
+cuda_version = _get_cuda_version()
+sm_version = _get_device_sm()
 
-if sm_version >= 90:
-    nvcc_flags.extend(
-        [
-            "-DFLASHINFER_ENABLE_FP8",
-            "-DFLASHINFER_ENABLE_FP8_E4M3",
-            "-DFLASHINFER_ENABLE_FP8_E5M2",
-        ]
-    )
-if sm_version >= 80:
-    nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
+if torch.cuda.is_available():
+    if cuda_version >= (12, 0) and sm_version >= 90:
+        nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
+    if sm_version >= 90:
+        nvcc_flags.extend(
+            [
+                "-DFLASHINFER_ENABLE_FP8",
+                "-DFLASHINFER_ENABLE_FP8_E4M3",
+                "-DFLASHINFER_ENABLE_FP8_E5M2",
+            ]
+        )
+    if sm_version >= 80:
+        nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
+else:
+    # compilation environment without GPU
+    if enable_sm90a:
+        nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
+    if enable_fp8:
+        nvcc_flags.extend(
+            [
+                "-DFLASHINFER_ENABLE_FP8",
+                "-DFLASHINFER_ENABLE_FP8_E4M3",
+                "-DFLASHINFER_ENABLE_FP8_E5M2",
+            ]
+        )
+    if enable_bf16:
+        nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
 
 for flag in [
     "-D__CUDA_NO_HALF_OPERATORS__",
@@ -82,6 +106,7 @@ for flag in [
         torch.utils.cpp_extension.COMMON_NVCC_FLAGS.remove(flag)
     except ValueError:
         pass
+
 cxx_flags = ["-O3"]
 libraries = ["c10", "torch", "torch_python", "cuda"]
 extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
@@ -116,11 +141,11 @@ ext_modules = [
 
 setup(
     name="sgl-kernel",
-    version=__version__,
+    version=_get_version(),
     packages=find_packages(),
     package_dir={"": "src"},
     ext_modules=ext_modules,
     cmdclass={"build_ext": BuildExtension},
 )
 
-update_wheel_platform_tag()
+_update_wheel_platform_tag()