Support FP4 gemm (1/2) (#3899)

This commit is contained in:
Trevor Morris
2025-03-24 19:50:23 -07:00
committed by GitHub
parent 22c3702e1e
commit e9f8e42318
11 changed files with 1245 additions and 5 deletions

View File

@@ -153,6 +153,10 @@ sources = [
"csrc/gemm/fp8_gemm_kernel.cu",
"csrc/gemm/fp8_blockwise_gemm_kernel.cu",
"csrc/gemm/int8_gemm_kernel.cu",
"csrc/gemm/nvfp4_quant_entry.cu",
"csrc/gemm/nvfp4_quant_kernels.cu",
"csrc/gemm/nvfp4_scaled_mm_entry.cu",
"csrc/gemm/nvfp4_scaled_mm_kernels.cu",
"csrc/gemm/per_token_group_quant_8bit.cu",
"csrc/gemm/per_token_quant_fp8.cu",
"csrc/gemm/per_tensor_quant_fp8.cu",
@@ -169,6 +173,7 @@ sources = [
enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1"
enable_fp4 = os.getenv("SGL_KERNEL_ENABLE_FP4", "0") == "1"
enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1"
enable_sm100a = os.getenv("SGL_KERNEL_ENABLE_SM100A", "0") == "1"
cuda_version = _get_cuda_version()
@@ -180,6 +185,7 @@ if torch.cuda.is_available():
if cuda_version >= (12, 8) and sm_version >= 100:
nvcc_flags.append("-gencode=arch=compute_100,code=sm_100")
nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
nvcc_flags.append("-DENABLE_NVFP4=1")
else:
nvcc_flags.append("-use_fast_math")
if sm_version >= 90:
@@ -188,12 +194,12 @@ if torch.cuda.is_available():
nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
else:
# compilation environment without GPU
if enable_sm90a:
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
if enable_sm100a:
nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
else:
nvcc_flags.append("-use_fast_math")
if enable_sm90a:
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
if enable_fp4:
nvcc_flags.append("-DENABLE_NVFP4=1")
if enable_fp8:
nvcc_flags.extend(nvcc_flags_fp8)
if enable_bf16: