Fix sgl-kernel compile for sm80 (#3046)
This commit is contained in:
@@ -24,6 +24,22 @@ def update_wheel_platform_tag():
|
|||||||
old_wheel.rename(new_wheel)
|
old_wheel.rename(new_wheel)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cuda_version():
|
||||||
|
if torch.version.cuda:
|
||||||
|
return tuple(map(int, torch.version.cuda.split(".")))
|
||||||
|
return (0, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_sm():
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
major, minor = torch.cuda.get_device_capability()
|
||||||
|
return major * 10 + minor
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
cuda_version = get_cuda_version()
|
||||||
|
sm_version = get_device_sm()
|
||||||
|
|
||||||
cutlass = root / "3rdparty" / "cutlass"
|
cutlass = root / "3rdparty" / "cutlass"
|
||||||
flashinfer = root / "3rdparty" / "flashinfer"
|
flashinfer = root / "3rdparty" / "flashinfer"
|
||||||
include_dirs = [
|
include_dirs = [
|
||||||
@@ -42,12 +58,15 @@ nvcc_flags = [
|
|||||||
"-gencode=arch=compute_80,code=sm_80",
|
"-gencode=arch=compute_80,code=sm_80",
|
||||||
"-gencode=arch=compute_89,code=sm_89",
|
"-gencode=arch=compute_89,code=sm_89",
|
||||||
"-gencode=arch=compute_90,code=sm_90",
|
"-gencode=arch=compute_90,code=sm_90",
|
||||||
"-gencode=arch=compute_90a,code=sm_90a",
|
|
||||||
"-std=c++17",
|
"-std=c++17",
|
||||||
"-use_fast_math",
|
"-use_fast_math",
|
||||||
"-DFLASHINFER_ENABLE_F16",
|
"-DFLASHINFER_ENABLE_F16",
|
||||||
"-DFLASHINFER_ENABLE_BF16",
|
"-DFLASHINFER_ENABLE_BF16",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if cuda_version >= (12, 0) and sm_version >= 90:
|
||||||
|
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
|
||||||
|
|
||||||
for flag in [
|
for flag in [
|
||||||
"-D__CUDA_NO_HALF_OPERATORS__",
|
"-D__CUDA_NO_HALF_OPERATORS__",
|
||||||
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
||||||
|
|||||||
Reference in New Issue
Block a user