Fix sgl-kernel compile for sm80 (#3046)

This commit is contained in:
Ke Bao
2025-01-22 16:49:08 +08:00
committed by GitHub
parent 3d8f1c9bcf
commit 6fc37bd8ee

View File

@@ -24,6 +24,22 @@ def update_wheel_platform_tag():
old_wheel.rename(new_wheel)
def get_cuda_version():
if torch.version.cuda:
return tuple(map(int, torch.version.cuda.split(".")))
return (0, 0)
def get_device_sm():
if torch.cuda.is_available():
major, minor = torch.cuda.get_device_capability()
return major * 10 + minor
return 0
cuda_version = get_cuda_version()
sm_version = get_device_sm()
cutlass = root / "3rdparty" / "cutlass"
flashinfer = root / "3rdparty" / "flashinfer"
include_dirs = [
@@ -42,12 +58,15 @@ nvcc_flags = [
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_89,code=sm_89",
"-gencode=arch=compute_90,code=sm_90",
"-gencode=arch=compute_90a,code=sm_90a",
"-std=c++17",
"-use_fast_math",
"-DFLASHINFER_ENABLE_F16",
"-DFLASHINFER_ENABLE_BF16",
]
if cuda_version >= (12, 0) and sm_version >= 90:
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
for flag in [
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",