From 6fc37bd8ee3535673835712fa76a973bda0cb450 Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Wed, 22 Jan 2025 16:49:08 +0800
Subject: [PATCH] Fix sgl-kernel compile for sm80 (#3046)

---
Notes (this area is ignored by `git am`; it is not part of the commit message):

  * The patch adds get_cuda_version() and get_device_sm() to
    sgl-kernel/setup.py and stops listing the compute_90a/sm_90a -gencode
    flag unconditionally: it is now appended only when cuda_version is
    >= (12, 0) and sm_version is >= 90.
  * NOTE(review): get_device_sm() returns 0 when torch.cuda.is_available()
    is false, so a build on a host without a visible GPU never gets the
    sm_90a flag -- confirm that is intended for wheel/CI builds.
  * NOTE(review): the line structure below was restored from a copy whose
    whitespace had been collapsed. Line breaks follow the hunk line counts;
    the depth of the Python indentation (in particular on the
    old_wheel.rename context line) is inferred -- verify against the blobs
    named in the index line before applying.

 sgl-kernel/setup.py | 21 ++++++++++++++++++++-
 1 file changed, 20 insertions(+), 1 deletion(-)

diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index a8d9517bb..1aea485ff 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -24,6 +24,22 @@ def update_wheel_platform_tag():
         old_wheel.rename(new_wheel)
 
 
+def get_cuda_version():
+    if torch.version.cuda:
+        return tuple(map(int, torch.version.cuda.split(".")))
+    return (0, 0)
+
+
+def get_device_sm():
+    if torch.cuda.is_available():
+        major, minor = torch.cuda.get_device_capability()
+        return major * 10 + minor
+    return 0
+
+
+cuda_version = get_cuda_version()
+sm_version = get_device_sm()
+
 cutlass = root / "3rdparty" / "cutlass"
 flashinfer = root / "3rdparty" / "flashinfer"
 include_dirs = [
@@ -42,12 +58,15 @@ nvcc_flags = [
     "-gencode=arch=compute_80,code=sm_80",
     "-gencode=arch=compute_89,code=sm_89",
     "-gencode=arch=compute_90,code=sm_90",
-    "-gencode=arch=compute_90a,code=sm_90a",
     "-std=c++17",
     "-use_fast_math",
     "-DFLASHINFER_ENABLE_F16",
     "-DFLASHINFER_ENABLE_BF16",
 ]
+
+if cuda_version >= (12, 0) and sm_version >= 90:
+    nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
+
 for flag in [
     "-D__CUDA_NO_HALF_OPERATORS__",
     "-D__CUDA_NO_HALF_CONVERSIONS__",