feat: integrate bmm_fp8 kernel into sgl-kernel (#3056)

This commit is contained in:
Yineng Zhang
2025-01-23 00:39:38 +08:00
committed by GitHub
parent b2bd8f444c
commit bf669606eb
6 changed files with 131 additions and 12 deletions

View File

@@ -62,12 +62,22 @@ nvcc_flags = [
"-std=c++17",
"-use_fast_math",
"-DFLASHINFER_ENABLE_F16",
"-DFLASHINFER_ENABLE_BF16",
]
if cuda_version >= (12, 0) and sm_version >= 90:
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
if sm_version >= 90:
nvcc_flags.extend(
[
"-DFLASHINFER_ENABLE_FP8",
"-DFLASHINFER_ENABLE_FP8_E4M3",
"-DFLASHINFER_ENABLE_FP8_E5M2",
]
)
if sm_version >= 80:
nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
for flag in [
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",