Support Blackwell Block Scale FP8 Gemm (#4278)

This commit is contained in:
Elfie Guo
2025-03-12 14:17:11 -07:00
committed by GitHub
parent 10b544ae9b
commit 7c86671131
3 changed files with 207 additions and 2 deletions

View File

@@ -122,7 +122,6 @@ nvcc_flags = [
"-gencode=arch=compute_89,code=sm_89",
"-gencode=arch=compute_90,code=sm_90",
"-std=c++17",
"-use_fast_math",
"-DFLASHINFER_ENABLE_F16",
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
"-DCUTLASS_VERSIONS_GENERATED",
@@ -169,12 +168,16 @@ sources = [
enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1"
enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1"
enable_sm100a = os.getenv("SGL_KERNEL_ENABLE_SM100A", "0") == "1"
cuda_version = _get_cuda_version()
sm_version = _get_device_sm()
if torch.cuda.is_available():
if cuda_version >= (12, 0) and sm_version >= 90:
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
if cuda_version >= (12, 8) and sm_version >= 100:
nvcc_flags.append("-gencode=arch=compute_100,code=sm_100")
nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
if sm_version >= 90:
nvcc_flags.extend(nvcc_flags_fp8)
if sm_version >= 80:
@@ -183,6 +186,8 @@ else:
# compilation environment without GPU
if enable_sm90a:
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
if enable_sm100a:
nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
if enable_fp8:
nvcc_flags.extend(nvcc_flags_fp8)
if enable_bf16: