Support FP4 gemm (1/2) (#3899)
This commit is contained in:
@@ -153,6 +153,10 @@ sources = [
|
||||
"csrc/gemm/fp8_gemm_kernel.cu",
|
||||
"csrc/gemm/fp8_blockwise_gemm_kernel.cu",
|
||||
"csrc/gemm/int8_gemm_kernel.cu",
|
||||
"csrc/gemm/nvfp4_quant_entry.cu",
|
||||
"csrc/gemm/nvfp4_quant_kernels.cu",
|
||||
"csrc/gemm/nvfp4_scaled_mm_entry.cu",
|
||||
"csrc/gemm/nvfp4_scaled_mm_kernels.cu",
|
||||
"csrc/gemm/per_token_group_quant_8bit.cu",
|
||||
"csrc/gemm/per_token_quant_fp8.cu",
|
||||
"csrc/gemm/per_tensor_quant_fp8.cu",
|
||||
@@ -169,6 +173,7 @@ sources = [
|
||||
|
||||
enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
|
||||
enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1"
|
||||
enable_fp4 = os.getenv("SGL_KERNEL_ENABLE_FP4", "0") == "1"
|
||||
enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1"
|
||||
enable_sm100a = os.getenv("SGL_KERNEL_ENABLE_SM100A", "0") == "1"
|
||||
cuda_version = _get_cuda_version()
|
||||
@@ -180,6 +185,7 @@ if torch.cuda.is_available():
|
||||
if cuda_version >= (12, 8) and sm_version >= 100:
|
||||
nvcc_flags.append("-gencode=arch=compute_100,code=sm_100")
|
||||
nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
|
||||
nvcc_flags.append("-DENABLE_NVFP4=1")
|
||||
else:
|
||||
nvcc_flags.append("-use_fast_math")
|
||||
if sm_version >= 90:
|
||||
@@ -188,12 +194,12 @@ if torch.cuda.is_available():
|
||||
nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
|
||||
else:
|
||||
# compilation environment without GPU
|
||||
if enable_sm90a:
|
||||
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
|
||||
if enable_sm100a:
|
||||
nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
|
||||
else:
|
||||
nvcc_flags.append("-use_fast_math")
|
||||
if enable_sm90a:
|
||||
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
|
||||
if enable_fp4:
|
||||
nvcc_flags.append("-DENABLE_NVFP4=1")
|
||||
if enable_fp8:
|
||||
nvcc_flags.extend(nvcc_flags_fp8)
|
||||
if enable_bf16:
|
||||
|
||||
Reference in New Issue
Block a user