Support Blackwell Block Scale FP8 Gemm (#4278)
This commit is contained in:
@@ -122,7 +122,6 @@ nvcc_flags = [
|
||||
"-gencode=arch=compute_89,code=sm_89",
|
||||
"-gencode=arch=compute_90,code=sm_90",
|
||||
"-std=c++17",
|
||||
"-use_fast_math",
|
||||
"-DFLASHINFER_ENABLE_F16",
|
||||
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
|
||||
"-DCUTLASS_VERSIONS_GENERATED",
|
||||
@@ -169,12 +168,16 @@ sources = [
|
||||
enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
|
||||
enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1"
|
||||
enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1"
|
||||
enable_sm100a = os.getenv("SGL_KERNEL_ENABLE_SM100A", "0") == "1"
|
||||
cuda_version = _get_cuda_version()
|
||||
sm_version = _get_device_sm()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
if cuda_version >= (12, 0) and sm_version >= 90:
|
||||
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
|
||||
if cuda_version >= (12, 8) and sm_version >= 100:
|
||||
nvcc_flags.append("-gencode=arch=compute_100,code=sm_100")
|
||||
nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
|
||||
if sm_version >= 90:
|
||||
nvcc_flags.extend(nvcc_flags_fp8)
|
||||
if sm_version >= 80:
|
||||
@@ -183,6 +186,8 @@ else:
|
||||
# compilation environment without GPU
|
||||
if enable_sm90a:
|
||||
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
|
||||
if enable_sm100a:
|
||||
nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
|
||||
if enable_fp8:
|
||||
nvcc_flags.extend(nvcc_flags_fp8)
|
||||
if enable_bf16:
|
||||
|
||||
Reference in New Issue
Block a user